Class EnsembleTeacherModel<T>
- Namespace
- AiDotNet.KnowledgeDistillation.Teachers
- Assembly
- AiDotNet.dll
Ensemble teacher model that combines predictions from multiple teacher models.
public class EnsembleTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance
- TeacherModelBase<Vector<T>, Vector<T>, T> → EnsembleTeacherModel<T>
- Implements
- ITeacherModel<Vector<T>, Vector<T>>
- IJitCompilable<T>
- Inherited Members
Remarks
For Beginners: Ensemble learning combines multiple models to create a stronger, more robust teacher. The intuition is similar to seeking advice from multiple experts rather than relying on a single expert.
Benefits of Ensemble Teachers:
- **Higher Accuracy**: Ensembles typically outperform their individual member models
- **Better Calibration**: Averaging reduces overconfidence
- **Robustness**: Less sensitive to individual model biases
- **Knowledge Diversity**: Student learns from complementary perspectives
Common Ensemble Strategies:
- **Uniform Average**: Equal weight to all teachers (default)
- **Weighted Average**: More weight to better-performing teachers
- **Voting**: For classification, majority vote
- **Stacking**: Meta-model combines predictions
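To make the first and third strategies concrete, here is a small language-agnostic sketch (in Python, with made-up probability vectors) of uniform averaging versus majority voting over three teachers:

```python
# Toy class-probability outputs from three hypothetical teachers
# for one input, over three classes.
teacher_probs = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.4, 0.5, 0.1],
]

# Uniform average: equal weight (1/N) for every teacher.
n = len(teacher_probs)
avg = [sum(p[c] for p in teacher_probs) / n for c in range(3)]

# Majority vote: each teacher votes for its argmax class.
votes = [max(range(3), key=lambda c: p[c]) for p in teacher_probs]
winner = max(set(votes), key=votes.count)

print(avg)     # averaged class distribution
print(winner)  # class chosen by majority vote
```

Note that here both strategies agree on class 1, but averaging also preserves how confident the ensemble is, which is why weighted averaging is the mode that supports distillation best.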
Real-world Analogy: Imagine learning to play chess from multiple grandmasters. Each has different playing styles and strategies. By learning from all of them, you develop a more well-rounded understanding of the game than you would from just one teacher.
Practical Example: Train 3-5 models with different:
- Initializations (different random seeds)
- Architectures (CNN, ResNet, Transformer)
- Hyperparameters (learning rates, depths)
Combine them to create a powerful ensemble teacher.
References:
- You et al. (2017). Learning from Multiple Teacher Networks. KDD.
- Fukuda et al. (2017). Efficient Knowledge Distillation from an Ensemble of Teachers.
Constructors
EnsembleTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[], double[]?, EnsembleAggregationMode)
Initializes a new instance of the EnsembleTeacherModel class.
public EnsembleTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[] teachers, double[]? weights = null, EnsembleAggregationMode aggregationMode = EnsembleAggregationMode.WeightedAverage)
Parameters
teachers (ITeacherModel<Vector<T>, Vector<T>>[]): Array of teacher models to ensemble.
weights (double[]?): Optional weights for each teacher (default: uniform). Must sum to 1.0.
aggregationMode (EnsembleAggregationMode): How to combine teacher predictions (default: WeightedAverage).
Remarks
For Beginners: Create an ensemble by providing multiple trained teacher models. If weights are not specified, all teachers contribute equally.
Example:
```csharp
var teacher1 = new TeacherModelWrapper<double>(model1);
var teacher2 = new TeacherModelWrapper<double>(model2);
var teacher3 = new TeacherModelWrapper<double>(model3);

// Uniform ensemble (equal weights)
var ensemble = new EnsembleTeacherModel<double>(
    new[] { teacher1, teacher2, teacher3 }
);

// Weighted ensemble (based on validation accuracy)
var ensemble2 = new EnsembleTeacherModel<double>(
    teachers: new[] { teacher1, teacher2, teacher3 },
    weights: new[] { 0.5, 0.3, 0.2 } // Best model gets 50% weight
);
```
Choosing Weights:
- **Uniform**: Use when teachers perform similarly
- **Validation-based**: Weight by validation accuracy
- **Confidence-based**: Weight by prediction confidence
- **Diversity-based**: Weight to maximize diversity
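One simple way to derive validation-based weights (a sketch, not part of the library; the accuracy numbers are made up) is to normalize each teacher's validation accuracy so the weights sum to 1.0, as the weights parameter requires:

```python
# Hypothetical validation accuracies for three teachers.
accuracies = [0.92, 0.88, 0.80]

# Normalize so the weights sum to 1.0 (the constructor's requirement).
total = sum(accuracies)
weights = [a / total for a in accuracies]

print(weights)  # better teachers end up with larger weights
```

The resulting list can be passed directly as the weights argument of the constructor.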
Properties
OutputDimension
Gets the output dimension (same for all teachers).
public override int OutputDimension { get; }
Property Value
- int
SupportsJitCompilation
Gets whether this teacher supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
Returns
true if WeightedAverage mode is used and all teachers support JIT compilation; otherwise, false.
Remarks
For Beginners: Ensemble JIT compilation is supported when:
1. WeightedAverage aggregation mode is used (other modes have dynamic operations)
2. All component teachers implement IJitCompilable and support JIT
The ensemble computation graph combines each teacher's graph with weighted addition.
TeacherCount
Gets the number of teachers in the ensemble.
public int TeacherCount { get; }
Property Value
- int
Methods
ExportComputationGraph(List<ComputationNode<T>>)
Exports the ensemble computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the weighted ensemble output.
Remarks
The ensemble graph combines each teacher's computation graph using weighted addition:
output = w1 * teacher1_output + w2 * teacher2_output + ... + wN * teacherN_output
For Beginners: This creates a combined computation graph that:
1. Creates separate computation paths for each teacher
2. Multiplies each teacher's output by its weight
3. Sums all weighted outputs
Expected speedup: 2-4x for inference after JIT compilation.
Exceptions
- NotSupportedException
Thrown when the aggregation mode is not WeightedAverage or when any teacher does not support JIT.
GetLogits(Vector<T>)
Gets ensemble logits by combining predictions from all teachers.
public override Vector<T> GetLogits(Vector<T> input)
Parameters
input (Vector<T>): Input data.
Returns
- Vector<T>
Ensemble logits.
Remarks
For Beginners: This combines logits from all teachers according to the aggregation mode (usually weighted average).
Architecture Note: Returns raw ensemble logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.
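The weighted-average aggregation itself is just an element-wise weighted sum of each teacher's logit vector. A minimal language-agnostic sketch (illustrative numbers, not the AiDotNet implementation):

```python
# Logit vectors from three hypothetical teachers for one input.
logits = [
    [2.0, 0.5, -1.0],
    [1.5, 1.0,  0.0],
    [2.5, 0.0, -0.5],
]
weights = [0.5, 0.3, 0.2]  # must sum to 1.0

# Element-wise weighted sum: ensemble[c] = sum_i w_i * logits_i[c]
ensemble = [
    sum(w * t[c] for w, t in zip(weights, logits))
    for c in range(len(logits[0]))
]
print(ensemble)
```

The output stays in logit space; temperature scaling and softmax happen later, in the distillation strategy.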
UpdateWeights(Vector<T>[], Vector<T>[])
Updates teacher weights based on performance (for adaptive weighting).
public void UpdateWeights(Vector<T>[] validationInputs, Vector<T>[] validationLabels)
Parameters
validationInputs (Vector<T>[]): Validation inputs for evaluation.
validationLabels (Vector<T>[]): Validation labels.
Remarks
For Advanced Users: Call this periodically to adjust weights based on each teacher's current performance. Better teachers get higher weights.
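The exact update rule is internal to the library, but one plausible sketch of performance-based reweighting (an assumption, not the actual AiDotNet algorithm, and the sharpness knob is made up) is a softmax over each teacher's validation score, so better teachers get disproportionately more weight while the weights still sum to 1.0:

```python
import math

def update_weights(scores, sharpness=5.0):
    """Softmax over per-teacher validation scores (higher is better).

    `sharpness` controls how strongly the best teacher dominates.
    """
    exps = [math.exp(sharpness * s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Teacher 1 scored best this round, so its weight grows.
weights = update_weights([0.95, 0.85, 0.70])
print(weights)
```

Calling this periodically during distillation mirrors the adaptive behavior described above: each evaluation round shifts weight toward the currently best-performing teachers.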