Class OnlineTeacherModel<T>

Namespace
AiDotNet.KnowledgeDistillation.Teachers
Assembly
AiDotNet.dll

Online teacher model that updates its parameters during student training.

public class OnlineTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
TeacherModelBase<Vector<T>, Vector<T>, T>
OnlineTeacherModel<T>
Implements
ITeacherModel<Vector<T>, Vector<T>>
IJitCompilable<T>

Remarks

For Beginners: Unlike standard distillation, where the teacher is frozen, online distillation allows the teacher to continue learning during student training. This is useful for:
- Continuous learning scenarios
- Evolving data distributions
- Co-training the teacher and student simultaneously

How It Works:
1. Initialize the teacher model (pre-trained or randomly initialized)
2. During student training, also update the teacher with new data
3. The teacher provides evolving knowledge to the student
4. Both models improve together
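The steps above can be sketched as a training loop. Everything student-side here (`student`, `TrainStep`) and the batch source `trainingBatches` are hypothetical placeholders; only GetLogits and Update are documented members of this class:

```csharp
// Hypothetical online distillation loop (student API is illustrative).
foreach (var (input, label) in trainingBatches)
{
    var teacherLogits = teacher.GetLogits(input);   // evolving teacher knowledge
    student.TrainStep(input, label, teacherLogits); // student distillation step (hypothetical)
    teacher.Update(input, label);                   // teacher keeps learning too
}
```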

Real-world Analogy: Imagine a mentor and apprentice both continuing to learn as they work together. The mentor (teacher) doesn't just transfer old knowledge - they also learn from new experiences and share those insights with the apprentice (student).

Use Cases:
- **Streaming Data**: New data arrives continuously
- **Domain Adaptation**: The distribution shifts over time
- **Co-training**: Teacher and student help each other
- **Incremental Learning**: Models must adapt to new classes or tasks

Update Strategies:
- **EMA (Exponential Moving Average)**: Smooth updates, stable teacher
- **Periodic Sync**: Update the teacher every N steps
- **Gradient-based**: Teacher trained with a separate loss
- **Momentum**: Teacher follows the student with momentum
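As a concrete illustration of the EMA strategy, the teacher's parameters are blended toward the student's at each update. The helper below is a standalone sketch of that arithmetic, not an AiDotNet API:

```csharp
// EMA update rule: teacher = rate * teacher + (1 - rate) * student.
// With the constructor default rate of 0.999, the teacher moves slowly,
// smoothing out noise in the student's training trajectory.
static double[] EmaUpdate(double[] teacher, double[] student, double rate)
{
    var updated = new double[teacher.Length];
    for (int i = 0; i < teacher.Length; i++)
        updated[i] = rate * teacher[i] + (1.0 - rate) * student[i];
    return updated;
}
```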

Advantages:
- Adapts to changing data
- No pre-trained teacher required
- Teacher and student can improve together
- Suitable for lifelong learning

Challenges:
- Risk of the teacher forgetting or degrading
- Update rate requires careful tuning
- More complex training dynamics
- Harder to debug

References:
- Zhang et al. (2018). Deep Mutual Learning. CVPR.
- Anil et al. (2018). Large Scale Distributed Neural Network Training through Online Distillation. ICLR.

Constructors

OnlineTeacherModel(IJitCompilable<T>, int, int, Action<Vector<T>, Vector<T>>?, OnlineUpdateMode, double, int)

Initializes a new instance of the OnlineTeacherModel class using a JIT-compilable model.

public OnlineTeacherModel(IJitCompilable<T> jitCompilableModel, int inputDimension, int outputDimension, Action<Vector<T>, Vector<T>>? teacherUpdate = null, OnlineUpdateMode updateMode = OnlineUpdateMode.EMA, double updateRate = 0.999, int updateFrequency = 1)

Parameters

jitCompilableModel IJitCompilable<T>

A JIT-compilable model for forward pass.

inputDimension int

Input dimension of the teacher.

outputDimension int

Output dimension of the teacher.

teacherUpdate Action<Vector<T>, Vector<T>>

Optional function to update teacher parameters.

updateMode OnlineUpdateMode

How to update the teacher (default: EMA).

updateRate double

Update rate for EMA or learning rate (default: 0.999 for EMA).

updateFrequency int

How often to update (default: every step).

Remarks

JIT Support: This constructor enables JIT compilation for inference when the underlying model supports it. Note that updates still use the teacherUpdate function if provided.
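A hedged construction sketch for this overload. It assumes `myJitModel` is an existing IJitCompilable<double> implementation you already have; only the documented signature is used:

```csharp
var teacher = new OnlineTeacherModel<double>(
    jitCompilableModel: myJitModel, // any IJitCompilable<double> forward model
    inputDimension: 32,
    outputDimension: 10);
// Remaining parameters keep their defaults: EMA mode, rate 0.999, update every step.
```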

OnlineTeacherModel(Func<Vector<T>, Vector<T>>, int, int, Action<Vector<T>, Vector<T>>?, OnlineUpdateMode, double, int)

Initializes a new instance of the OnlineTeacherModel class using function delegates.

public OnlineTeacherModel(Func<Vector<T>, Vector<T>> teacherForward, int inputDimension, int outputDimension, Action<Vector<T>, Vector<T>>? teacherUpdate = null, OnlineUpdateMode updateMode = OnlineUpdateMode.EMA, double updateRate = 0.999, int updateFrequency = 1)

Parameters

teacherForward Func<Vector<T>, Vector<T>>

Function to perform forward pass through teacher.

inputDimension int

Input dimension of the teacher.

outputDimension int

Output dimension of the teacher.

teacherUpdate Action<Vector<T>, Vector<T>>

Optional function to update teacher parameters (input, gradient).

updateMode OnlineUpdateMode

How to update the teacher (default: EMA).

updateRate double

Update rate for EMA or learning rate (default: 0.999 for EMA).

updateFrequency int

How often to update (default: every step).

Remarks

Note: This constructor creates a non-JIT-compilable teacher. For JIT support, use the constructor that accepts an IJitCompilable model.
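A minimal construction sketch for the delegate-based overload. The identity forward function is purely illustrative; a real teacher would wrap an actual model:

```csharp
var teacher = new OnlineTeacherModel<double>(
    teacherForward: x => x,            // illustrative forward pass (identity)
    inputDimension: 16,
    outputDimension: 16,
    teacherUpdate: null,               // no parameter updates in this sketch
    updateMode: OnlineUpdateMode.EMA,
    updateRate: 0.999,
    updateFrequency: 1);
// This teacher reports SupportsJitCompilation == false.
```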

Properties

IsUpdating

Gets or sets whether the teacher is currently updating.

public bool IsUpdating { get; set; }

Property Value

bool

OutputDimension

Gets the output dimension of the teacher model.

public override int OutputDimension { get; }

Property Value

int

SupportsJitCompilation

Gets whether this teacher supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

true if constructed with an IJitCompilable model that supports JIT compilation; false if constructed with function delegates, which cannot be exported as a computation graph.

Methods

ExportComputationGraph(List<ComputationNode<T>>)

Exports the computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input nodes.

Returns

ComputationNode<T>

The output computation node.

Remarks

When constructed with an IJitCompilable model, this method delegates to the underlying model's computation graph export. When constructed with function delegates, JIT compilation is not supported because function delegates can contain arbitrary code that cannot be represented as tensor operations.

To enable JIT compilation, use the constructor that accepts an IJitCompilable model instead of using function delegates.

Exceptions

NotSupportedException

Thrown when using function delegates instead of an IJitCompilable model.

GetLogits(Vector<T>)

Gets logits from the teacher model.

public override Vector<T> GetLogits(Vector<T> input)

Parameters

input Vector<T>

The input vector to evaluate.

Returns

Vector<T>

Remarks

Architecture Note: Returns raw logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.

PauseUpdates()

Pauses teacher updates (freezes teacher).

public void PauseUpdates()

ResetCounter()

Resets the update counter.

public void ResetCounter()

ResumeUpdates()

Resumes teacher updates.

public void ResumeUpdates()

Update(Vector<T>, Vector<T>)

Updates the teacher model with new data.

public void Update(Vector<T> input, Vector<T> targetOutput)

Parameters

input Vector<T>

Input that was used for prediction.

targetOutput Vector<T>

Target output for the teacher (can be ground truth or student prediction).

Remarks

For Beginners: Call this after each batch to update the teacher. The teacher learns from either:
- Ground truth labels (the teacher improves on the task)
- Student predictions (mutual learning: the teacher learns from the student too!)

Update modes:
- **EMA**: Teacher smoothly tracks the student, no explicit gradient
- **GradientBased**: Teacher trained with standard gradient descent
- **MomentumBased**: Teacher follows the student with momentum
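A hedged per-batch sketch combining Update with PauseUpdates/ResumeUpdates. The `batches` source and the student-side step are illustrative placeholders; the teacher calls are documented members:

```csharp
foreach (var (input, label) in batches)
{
    var logits = teacher.GetLogits(input); // raw logits for the distillation strategy
    // ... student distillation step using logits (not shown) ...
    teacher.Update(input, label);          // teacher learns from the ground truth
}

teacher.PauseUpdates();  // freeze the teacher, e.g. during evaluation
// ... evaluate ...
teacher.ResumeUpdates(); // continue online updates
```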