Class DistributedTeacherModel<T>
Namespace: AiDotNet.KnowledgeDistillation.Teachers
Assembly: AiDotNet.dll
Distributed teacher model that aggregates predictions from multiple distributed workers.
public class DistributedTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T
Inheritance: TeacherModelBase<Vector<T>, Vector<T>, T> → DistributedTeacherModel<T>
Implements: ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Constructors
DistributedTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[], AggregationMode)
public DistributedTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[] workers, AggregationMode aggregation = AggregationMode.Average)
Parameters
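workers ITeacherModel<Vector<T>, Vector<T>>[]
The worker teacher models whose predictions are aggregated.
aggregation AggregationMode
How the worker outputs are combined. Defaults to AggregationMode.Average.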
Properties
OutputDimension
Gets the number of output dimensions (e.g., number of classes for classification).
public override int OutputDimension { get; }
Property Value
- int
Remarks
For Implementers: Return the size of the output vector. For example, a 10-class classifier should return 10.
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
Returns
true if the Average aggregation mode is used and all workers support JIT compilation; otherwise, false.
Remarks
Note: Although "distributed" suggests workers on different machines, JIT compilation is supported when Average aggregation is used and all workers are local models that implement IJitCompilable<T>. Their computation graphs can then be combined into a single graph for optimized inference.
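The support rule above reduces to a simple conjunction. The following is a minimal sketch of that rule only, assuming nothing about how AiDotNet implements it: SketchAggregationMode, ISketchWorker, FakeWorker, and JitSupportSketch are hypothetical stand-ins, not AiDotNet types, and "Other" represents any non-Average mode.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for AggregationMode; "Other" represents any non-Average mode.
public enum SketchAggregationMode { Average, Other }

// Hypothetical stand-in for a worker teacher; only the JIT flag matters here.
public interface ISketchWorker
{
    bool SupportsJitCompilation { get; }
}

// Trivial worker used to exercise the rule.
public sealed class FakeWorker : ISketchWorker
{
    public FakeWorker(bool supportsJit) => SupportsJitCompilation = supportsJit;
    public bool SupportsJitCompilation { get; }
}

public static class JitSupportSketch
{
    // Documented rule: true only for Average mode when every worker supports JIT.
    public static bool SupportsJit(SketchAggregationMode mode, ISketchWorker[] workers)
    {
        if (workers is null) throw new ArgumentNullException(nameof(workers));
        return mode == SketchAggregationMode.Average
            && workers.All(w => w.SupportsJitCompilation);
    }
}
```

A single worker without JIT support, or any non-Average mode, is enough to make the whole teacher non-compilable.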
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the distributed worker computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes List<ComputationNode<T>>
List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the averaged worker output.
Remarks
The distributed graph combines the workers' computation graphs by averaging their outputs:
output = (worker1_output + worker2_output + ... + workerN_output) / N
where N is the number of workers.
Note: JIT compilation creates a single optimized computation graph combining all worker models. This is beneficial when workers are local models; for truly distributed inference across machines, use runtime aggregation instead.
Exceptions
- NotSupportedException
Thrown when the aggregation mode is not Average or when any worker does not support JIT.
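To make the averaging graph and its failure cases concrete, here is a toy sketch over double[] values. It is not AiDotNet's ComputationNode<T> API: SketchNode, InputNode, WorkerNode, MeanNode, SketchWorker, and GraphExportSketch are all hypothetical. It also assumes every worker reads one shared input node, which the documentation does not state.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical toy graph node; NOT AiDotNet's ComputationNode<T>.
public abstract class SketchNode
{
    public abstract double[] Evaluate(double[] input);
}

// Leaf node that stands for the graph input.
public sealed class InputNode : SketchNode
{
    public override double[] Evaluate(double[] input) => input;
}

// Applies one worker's function to the value produced by its child node.
public sealed class WorkerNode : SketchNode
{
    private readonly SketchNode _child;
    private readonly Func<double[], double[]> _fn;

    public WorkerNode(SketchNode child, Func<double[], double[]> fn)
    {
        _child = child;
        _fn = fn;
    }

    public override double[] Evaluate(double[] input) => _fn(_child.Evaluate(input));
}

// Element-wise mean of its children: (c1 + c2 + ... + cN) / N.
public sealed class MeanNode : SketchNode
{
    private readonly IReadOnlyList<SketchNode> _children;

    public MeanNode(IReadOnlyList<SketchNode> children) => _children = children;

    public override double[] Evaluate(double[] input)
    {
        double[][] outputs = _children.Select(c => c.Evaluate(input)).ToArray();
        var mean = new double[outputs[0].Length];
        foreach (double[] o in outputs)
            for (int i = 0; i < mean.Length; i++)
                mean[i] += o[i];
        for (int i = 0; i < mean.Length; i++)
            mean[i] /= outputs.Length;
        return mean;
    }
}

// Hypothetical worker description: a JIT flag plus the function it computes.
public sealed record SketchWorker(bool SupportsJit, Func<double[], double[]> Fn);

public static class GraphExportSketch
{
    public static SketchNode Export(List<SketchNode> inputNodes, bool averageMode,
                                    IReadOnlyList<SketchWorker> workers)
    {
        // Same failure conditions as the documented NotSupportedException.
        if (!averageMode)
            throw new NotSupportedException("Only Average aggregation can be JIT-compiled.");
        if (workers.Count == 0)
            throw new ArgumentException("At least one worker is required.", nameof(workers));
        if (workers.Any(w => !w.SupportsJit))
            throw new NotSupportedException("Every worker must support JIT compilation.");

        // Assumption: all workers read one shared input node.
        var input = new InputNode();
        inputNodes.Add(input); // populate the caller's list with the input node

        var workerOutputs = workers
            .Select(w => (SketchNode)new WorkerNode(input, w.Fn))
            .ToList();
        return new MeanNode(workerOutputs);
    }
}
```

Building one combined graph lets a compiler see all worker computations at once. This is why the documentation recommends runtime aggregation instead when workers live on separate machines.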
GetLogits(Vector<T>)
Gets aggregated logits from all distributed workers.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input Vector<T>
Input data.
Returns
- Vector<T>
Aggregated logits from all workers.
Remarks
Architecture Note: Returns raw aggregated logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.
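The split described in the architecture note can be sketched as two independent steps. This minimal version assumes that Average mode is an element-wise mean of the workers' logits and uses double[] instead of Vector<T>. LogitAggregationSketch and SoftmaxWithTemperature are hypothetical: the softmax helper stands in for what a distillation strategy would do and is not part of DistributedTeacherModel<T>.

```csharp
using System;
using System.Linq;

public static class LogitAggregationSketch
{
    // Teacher side: element-wise mean of worker logits, returned as raw logits.
    public static double[] AverageLogits(double[][] workerLogits)
    {
        if (workerLogits.Length == 0)
            throw new ArgumentException("At least one worker is required.", nameof(workerLogits));

        int dim = workerLogits[0].Length;
        var mean = new double[dim];
        foreach (double[] logits in workerLogits)
        {
            if (logits.Length != dim)
                throw new ArgumentException("All workers must share one output dimension.");
            for (int i = 0; i < dim; i++)
                mean[i] += logits[i];
        }
        for (int i = 0; i < dim; i++)
            mean[i] /= workerLogits.Length;
        return mean;
    }

    // Strategy side (hypothetical): temperature-scaled softmax over the raw logits.
    public static double[] SoftmaxWithTemperature(double[] logits, double temperature)
    {
        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max(); // subtract the max for numerical stability
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

Because the teacher returns unscaled logits, a strategy can reuse the same aggregated output at any temperature without asking the workers again.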