Class OnlineTeacherModel<T>
- Namespace: AiDotNet.KnowledgeDistillation.Teachers
- Assembly: AiDotNet.dll
Online teacher model that updates its parameters during student training.
public class OnlineTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance: TeacherModelBase&lt;Vector&lt;T&gt;, Vector&lt;T&gt;, T&gt; → OnlineTeacherModel&lt;T&gt;
- Implements: ITeacherModel&lt;Vector&lt;T&gt;, Vector&lt;T&gt;&gt;, IJitCompilable&lt;T&gt;
Remarks
For Beginners: Unlike standard distillation, where the teacher is frozen, online distillation allows the teacher to continue learning during student training. This is useful for:
- Continuous learning scenarios
- Evolving data distributions
- Co-training teacher and student simultaneously

How It Works:
1. Initialize the teacher model (pre-trained or randomly initialized)
2. During student training, also update the teacher with new data
3. The teacher provides evolving knowledge to the student
4. Both models improve together

Real-world Analogy: Imagine a mentor and an apprentice who both keep learning as they work together. The mentor (teacher) doesn't just transfer old knowledge; they also learn from new experiences and share those insights with the apprentice (student).

Use Cases:
- **Streaming Data**: New data arrives continuously
- **Domain Adaptation**: The distribution shifts over time
- **Co-training**: Teacher and student help each other
- **Incremental Learning**: Models must adapt to new classes/tasks

Update Strategies:
- **EMA (Exponential Moving Average)**: Smooth updates, stable teacher
- **Periodic Sync**: Update the teacher every N steps
- **Gradient-based**: Teacher trained with a separate loss
- **Momentum**: Teacher follows the student with momentum

Advantages:
- Adapts to changing data
- No pre-trained teacher required
- Teacher and student can improve together
- Suitable for lifelong learning

Challenges:
- Risk of the teacher forgetting or degrading
- Update rate requires careful tuning
- More complex training dynamics
- Harder to debug

References:
- Zhang et al. (2018). Deep Mutual Learning. CVPR.
- Anil et al. (2018). Large Scale Distributed Neural Network Training through Online Distillation.
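The EMA strategy above reduces to a single parameter-space rule: `teacher = rate * teacher + (1 - rate) * student`. A minimal standalone sketch of that rule (plain double arrays, not AiDotNet internals; the helper name `EmaUpdate` is illustrative):

```csharp
using System;

// Minimal sketch of an EMA teacher update (illustrative, not library code).
// With rate near 1 (e.g., 0.999) the teacher drifts only slightly toward
// the student on each call, which keeps its outputs stable.
double[] EmaUpdate(double[] teacher, double[] student, double rate)
{
    var updated = new double[teacher.Length];
    for (int i = 0; i < teacher.Length; i++)
        updated[i] = rate * teacher[i] + (1.0 - rate) * student[i];
    return updated;
}
```

With the default rate of 0.999, each call moves the teacher only 0.1% of the way toward the student.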
Constructors
OnlineTeacherModel(IJitCompilable<T>, int, int, Action<Vector<T>, Vector<T>>?, OnlineUpdateMode, double, int)
Initializes a new instance of the OnlineTeacherModel class using a JIT-compilable model.
public OnlineTeacherModel(IJitCompilable<T> jitCompilableModel, int inputDimension, int outputDimension, Action<Vector<T>, Vector<T>>? teacherUpdate = null, OnlineUpdateMode updateMode = OnlineUpdateMode.EMA, double updateRate = 0.999, int updateFrequency = 1)
Parameters
jitCompilableModel (IJitCompilable&lt;T&gt;): A JIT-compilable model for the forward pass.
inputDimension (int): Input dimension of the teacher.
outputDimension (int): Output dimension of the teacher.
teacherUpdate (Action&lt;Vector&lt;T&gt;, Vector&lt;T&gt;&gt;?): Optional function to update teacher parameters.
updateMode (OnlineUpdateMode): How to update the teacher (default: EMA).
updateRate (double): Update rate for EMA, or the learning rate (default: 0.999 for EMA).
updateFrequency (int): How often to update (default: 1, i.e., every step).
Remarks
JIT Support: This constructor enables JIT compilation for inference when the underlying model supports it. Note that updates still use the teacherUpdate function if provided.
OnlineTeacherModel(Func<Vector<T>, Vector<T>>, int, int, Action<Vector<T>, Vector<T>>?, OnlineUpdateMode, double, int)
Initializes a new instance of the OnlineTeacherModel class using function delegates.
public OnlineTeacherModel(Func<Vector<T>, Vector<T>> teacherForward, int inputDimension, int outputDimension, Action<Vector<T>, Vector<T>>? teacherUpdate = null, OnlineUpdateMode updateMode = OnlineUpdateMode.EMA, double updateRate = 0.999, int updateFrequency = 1)
Parameters
teacherForward (Func&lt;Vector&lt;T&gt;, Vector&lt;T&gt;&gt;): Function that performs the forward pass through the teacher.
inputDimension (int): Input dimension of the teacher.
outputDimension (int): Output dimension of the teacher.
teacherUpdate (Action&lt;Vector&lt;T&gt;, Vector&lt;T&gt;&gt;?): Optional function to update teacher parameters (arguments: input, gradient).
updateMode (OnlineUpdateMode): How to update the teacher (default: EMA).
updateRate (double): Update rate for EMA, or the learning rate (default: 0.999 for EMA).
updateFrequency (int): How often to update (default: 1, i.e., every step).
Remarks
Note: This constructor creates a non-JIT-compilable teacher. For JIT support, use the constructor that accepts an IJitCompilable model.
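The delegate pattern and the updateFrequency gating described above can be sketched with closures over plain double arrays (an illustrative stand-in, not the AiDotNet API, which uses Vector&lt;T&gt;):

```csharp
using System;

// Illustrative sketch of the delegate-based pattern: forward and update are
// supplied as delegates, and updates apply only every `updateFrequency` calls.
int step = 0, updatesApplied = 0;
int updateFrequency = 2;

Func<double[], double[]> forward = x => x;                    // toy forward pass
Action<double[], double[]> applyUpdate = (input, target) => updatesApplied++;

void Update(double[] input, double[] target)
{
    step++;
    if (step % updateFrequency != 0) return;  // skip off-cycle steps
    applyUpdate(input, target);
}
```

With updateFrequency = 2, calling Update four times applies the teacher update twice; the default of 1 updates on every call.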
Properties
IsUpdating
Gets or sets whether the teacher is currently updating.
public bool IsUpdating { get; set; }
Property Value
- bool
OutputDimension
Gets the output dimension of the teacher model.
public override int OutputDimension { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if constructed with an IJitCompilable model that supports JIT compilation; false if constructed with function delegates, which cannot be exported as a computation graph.
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List&lt;ComputationNode&lt;T&gt;&gt;): List to populate with input nodes.
Returns
- ComputationNode<T>
The output computation node.
Remarks
When constructed with an IJitCompilable model, this method delegates to the underlying model's computation graph export. When constructed with function delegates, JIT compilation is not supported because function delegates can contain arbitrary code that cannot be represented as tensor operations.
To enable JIT compilation, use the constructor that accepts an IJitCompilable model instead of using function delegates.
Exceptions
- NotSupportedException
Thrown when using function delegates instead of an IJitCompilable model.
GetLogits(Vector<T>)
Gets logits from the teacher model.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input (Vector&lt;T&gt;): The input vector to the teacher.
Returns
- Vector<T>
Remarks
Architecture Note: Returns raw logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.
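Because the teacher returns raw logits, temperature scaling and softmax live in the distillation strategy. A self-contained sketch of that step, using the standard softened-softmax formula (illustrative, not AiDotNet code):

```csharp
using System;

// Softened softmax: divide logits by temperature T before normalizing.
// T > 1 flattens the distribution, exposing "dark knowledge" to the student.
double[] SoftmaxWithTemperature(double[] logits, double temperature)
{
    var scaled = new double[logits.Length];
    double max = double.NegativeInfinity;
    for (int i = 0; i < logits.Length; i++)
    {
        scaled[i] = logits[i] / temperature;
        if (scaled[i] > max) max = scaled[i];   // subtract max for stability
    }
    double sum = 0.0;
    var probs = new double[logits.Length];
    for (int i = 0; i < logits.Length; i++)
    {
        probs[i] = Math.Exp(scaled[i] - max);
        sum += probs[i];
    }
    for (int i = 0; i < probs.Length; i++) probs[i] /= sum;
    return probs;
}
```

A strategy would call this on the vector returned by GetLogits; the teacher itself never applies the temperature.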
PauseUpdates()
Pauses teacher updates (freezes teacher).
public void PauseUpdates()
ResetCounter()
Resets the update counter.
public void ResetCounter()
ResumeUpdates()
Resumes teacher updates.
public void ResumeUpdates()
Update(Vector<T>, Vector<T>)
Updates the teacher model with new data.
public void Update(Vector<T> input, Vector<T> targetOutput)
Parameters
input (Vector&lt;T&gt;): Input that was used for prediction.
targetOutput (Vector&lt;T&gt;): Target output for the teacher (can be the ground truth or the student's prediction).
Remarks
For Beginners: Call this after each batch to update the teacher. The teacher learns from either:
- Ground truth labels (the teacher improves on the task)
- Student predictions (mutual learning: the teacher learns from the student too!)

Update modes:
- **EMA**: Teacher smoothly tracks the student; no explicit gradient
- **GradientBased**: Teacher trained with standard gradient descent
- **MomentumBased**: Teacher follows the student with momentum
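The EMA mode can be seen end to end in a toy loop: the student takes gradient-like steps toward a target while the teacher tracks the student via EMA and therefore converges too, without an explicit gradient. A standalone scalar sketch (illustrative only; the rates are arbitrary example values):

```csharp
// Toy online-distillation loop: the student moves toward a target value;
// the teacher smoothly follows the student via EMA (no explicit gradient).
double teacher = 0.0, student = 0.0, target = 1.0;
double studentLr = 0.5, emaRate = 0.9;

for (int step = 0; step < 50; step++)
{
    student += studentLr * (target - student);              // student update
    teacher = emaRate * teacher + (1 - emaRate) * student;  // EMA teacher update
}
```

After the loop both parameters sit near the target, with the teacher lagging slightly behind the student, which is exactly the smooth, stable behavior EMA mode is chosen for.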