
Class EnsembleTeacherModel<T>

Namespace
AiDotNet.KnowledgeDistillation.Teachers
Assembly
AiDotNet.dll

Ensemble teacher model that combines predictions from multiple teacher models.

public class EnsembleTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
TeacherModelBase<Vector<T>, Vector<T>, T>
EnsembleTeacherModel<T>
Implements
ITeacherModel<Vector<T>, Vector<T>>
IJitCompilable<T>

Remarks

For Beginners: Ensemble learning combines multiple models to create a stronger, more robust teacher. The intuition is similar to seeking advice from multiple experts rather than relying on a single expert.

Benefits of Ensemble Teachers:
- **Higher Accuracy**: An ensemble typically outperforms its individual members
- **Better Calibration**: Averaging reduces overconfidence
- **Robustness**: Less sensitive to the biases of any single model
- **Knowledge Diversity**: The student learns from complementary perspectives

Common Ensemble Strategies:
- **Uniform Average**: Equal weight for all teachers (the default when no weights are given)
- **Weighted Average**: More weight for better-performing teachers
- **Voting**: For classification, the class predicted by the most teachers wins
- **Stacking**: A meta-model learns how to combine the teachers' predictions
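The first three strategies can be sketched over plain `double[]` output vectors. This is a minimal illustration under stated assumptions, not the library's implementation: it does not use `Vector<T>` or `EnsembleAggregationMode`, the helper names `UniformAverage`, `WeightedAverage`, and `MajorityVote` are hypothetical, and stacking is omitted because it needs a separately trained meta-model.

```csharp
using System;

// Uniform average: every teacher's output vector gets the same weight 1/N.
static double[] UniformAverage(double[][] outputs)
{
    int dim = outputs[0].Length;
    var result = new double[dim];
    foreach (var o in outputs)
        for (int j = 0; j < dim; j++)
            result[j] += o[j] / outputs.Length;
    return result;
}

// Weighted average: teacher i contributes weights[i] * outputs[i].
static double[] WeightedAverage(double[][] outputs, double[] weights)
{
    int dim = outputs[0].Length;
    var result = new double[dim];
    for (int i = 0; i < outputs.Length; i++)
        for (int j = 0; j < dim; j++)
            result[j] += weights[i] * outputs[i][j];
    return result;
}

// Majority vote: each teacher votes for its arg-max class;
// ties are broken in favor of the lowest class index.
static int MajorityVote(double[][] outputs)
{
    int dim = outputs[0].Length;
    var votes = new int[dim];
    foreach (var o in outputs)
    {
        int best = 0;
        for (int j = 1; j < dim; j++)
            if (o[j] > o[best]) best = j;
        votes[best]++;
    }
    int winner = 0;
    for (int c = 1; c < dim; c++)
        if (votes[c] > votes[winner]) winner = c;
    return winner;
}

// Three hypothetical teachers, three classes.
var teacherOutputs = new[]
{
    new[] { 2.0, 1.0, 0.0 },
    new[] { 0.5, 1.5, 0.0 },
    new[] { 3.0, 0.0, 1.0 },
};
Console.WriteLine(string.Join(", ", UniformAverage(teacherOutputs)));
Console.WriteLine(string.Join(", ", WeightedAverage(teacherOutputs, new[] { 0.5, 0.3, 0.2 })));
Console.WriteLine(MajorityVote(teacherOutputs)); // 0
```

Averaging keeps the full output vector, which is what a distillation student needs as a soft target; voting collapses the ensemble to a single class index and discards that information.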

Real-world Analogy: Imagine learning to play chess from multiple grandmasters. Each has different playing styles and strategies. By learning from all of them, you develop a more well-rounded understanding of the game than you would from just one teacher.

Practical Example: Train 3-5 models that differ in:
- Initialization (different random seeds)
- Architecture (CNN, ResNet, Transformer)
- Hyperparameters (learning rates, depths)

Combine them to create a powerful ensemble teacher.

References:
- You et al. (2017). Learning from Multiple Teacher Networks. KDD.
- Fukuda et al. (2017). Efficient Knowledge Distillation from an Ensemble of Teachers. Interspeech.

Constructors

EnsembleTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[], double[]?, EnsembleAggregationMode)

Initializes a new instance of the EnsembleTeacherModel class.

public EnsembleTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[] teachers, double[]? weights = null, EnsembleAggregationMode aggregationMode = EnsembleAggregationMode.WeightedAverage)

Parameters

teachers ITeacherModel<Vector<T>, Vector<T>>[]

Array of teacher models to ensemble.

weights double[]

Optional weights, one per teacher (default: uniform). If provided, they must sum to 1.0.

aggregationMode EnsembleAggregationMode

How to combine teacher predictions (default: WeightedAverage).
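The weights parameter rules (null means uniform; otherwise one weight per teacher summing to 1.0) can be sketched as a small resolution step. This is a hypothetical helper, not the constructor's actual code: the documentation does not say how invalid weights are reported, so throwing `ArgumentException` and the `1e-6` tolerance are assumptions.

```csharp
using System;
using System.Linq;

// Hypothetical helper mirroring the documented rules for `weights`:
// null => uniform 1/N; otherwise exactly one weight per teacher, summing to 1.0.
// The exception type and tolerance are assumptions, not documented behavior.
static double[] ResolveWeights(int teacherCount, double[]? weights, double tolerance = 1e-6)
{
    if (teacherCount <= 0)
        throw new ArgumentException("At least one teacher is required.");
    if (weights is null)
        return Enumerable.Repeat(1.0 / teacherCount, teacherCount).ToArray();
    if (weights.Length != teacherCount)
        throw new ArgumentException("Provide exactly one weight per teacher.");
    // Compare with a tolerance because floating-point sums rarely hit 1.0 exactly.
    if (Math.Abs(weights.Sum() - 1.0) > tolerance)
        throw new ArgumentException("Weights must sum to 1.0.");
    return weights;
}

Console.WriteLine(string.Join(", ", ResolveWeights(4, null)));                      // uniform
Console.WriteLine(string.Join(", ", ResolveWeights(3, new[] { 0.5, 0.3, 0.2 })));  // as given
```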

Remarks

For Beginners: Create an ensemble by providing multiple trained teacher models. If weights are not specified, all teachers contribute equally.

Example:

var teacher1 = new TeacherModelWrapper<double>(model1);
var teacher2 = new TeacherModelWrapper<double>(model2);
var teacher3 = new TeacherModelWrapper<double>(model3);

// Uniform ensemble (equal weights)
var ensemble = new EnsembleTeacherModel<double>(
    new[] { teacher1, teacher2, teacher3 });

// Weighted ensemble (based on validation accuracy)
var ensemble2 = new EnsembleTeacherModel<double>(
    teachers: new[] { teacher1, teacher2, teacher3 },
    weights: new[] { 0.5, 0.3, 0.2 }); // Best model gets 50% weight

Choosing Weights:
- **Uniform**: Use when teachers perform similarly
- **Validation-based**: Weight by validation accuracy
- **Confidence-based**: Weight by prediction confidence
- **Diversity-based**: Choose weights that maximize diversity
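Validation-based weighting can be sketched by normalizing each teacher's validation accuracy so the weights sum to 1.0, as the constructor requires. This is one simple proportional scheme chosen for illustration; the helper name `WeightsFromAccuracy` and the accuracy values are hypothetical.

```csharp
using System;
using System.Linq;

// Proportional validation-based weights: w_i = acc_i / sum(acc).
// The result sums to 1.0, matching the constructor's requirement.
static double[] WeightsFromAccuracy(double[] accuracies)
{
    double total = accuracies.Sum();
    if (total <= 0)
        throw new ArgumentException("At least one accuracy must be positive.");
    return accuracies.Select(a => a / total).ToArray();
}

// Hypothetical validation accuracies for three teachers.
var weights = WeightsFromAccuracy(new[] { 0.90, 0.85, 0.80 });
Console.WriteLine(string.Join(", ", weights.Select(x => x.ToString("F3"))));
```

Proportional weights stay close to uniform when accuracies are similar, as in this example. A sharper split such as the `{ 0.5, 0.3, 0.2 }` used in the constructor example has to come from a scheme that amplifies differences, for instance a softmax over accuracies with a low temperature.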

Properties

OutputDimension

Gets the output dimension (same for all teachers).

public override int OutputDimension { get; }

Property Value

int

SupportsJitCompilation

Gets whether this teacher supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

Returns true if WeightedAverage mode is used and all teachers support JIT compilation; otherwise, false.

Remarks

For Beginners: Ensemble JIT compilation is supported when:
1. The WeightedAverage aggregation mode is used (other modes involve dynamic operations)
2. All component teachers implement IJitCompilable and support JIT compilation

The ensemble computation graph combines each teacher's graph with weighted addition.
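The eligibility rule above reduces to a single boolean condition. In this sketch the aggregation mode is a string and each teacher's JIT support is a `bool`; both are stand-ins for the real `EnsembleAggregationMode` enum and `IJitCompilable<T>` checks, and the helper name is hypothetical.

```csharp
using System;
using System.Linq;

// JIT is possible only in WeightedAverage mode AND when every teacher supports JIT.
// A single non-JIT teacher disables JIT for the whole ensemble.
static bool EnsembleSupportsJit(string aggregationMode, bool[] teacherSupportsJit) =>
    aggregationMode == "WeightedAverage" && teacherSupportsJit.All(s => s);

Console.WriteLine(EnsembleSupportsJit("WeightedAverage", new[] { true, true, true }));  // True
Console.WriteLine(EnsembleSupportsJit("WeightedAverage", new[] { true, false, true })); // False
```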

TeacherCount

Gets the number of teachers in the ensemble.

public int TeacherCount { get; }

Property Value

int

Methods

ExportComputationGraph(List<ComputationNode<T>>)

Exports the ensemble computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the weighted ensemble output.

Remarks

The ensemble graph combines each teacher's computation graph using weighted addition:

output = w1 * teacher1_output + w2 * teacher2_output + ... + wN * teacherN_output

For Beginners: This creates a combined computation graph that:
1. Creates a separate computation path for each teacher
2. Multiplies each teacher's output by its weight
3. Sums all weighted outputs

Expected speedup: 2-4x for inference after JIT compilation.
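The three steps above can be sketched by composing functions instead of graph nodes. Each teacher "graph" is modeled as a `Func<double[], double[]>`; this stands in for `ComputationNode<T>` and is not the library's graph API. The helper name `BuildEnsembleGraph` and the toy teachers are hypothetical.

```csharp
using System;
using System.Linq;

// Compose per-teacher functions into one ensemble function:
// output = w1 * t1(x) + w2 * t2(x) + ... + wN * tN(x)
static Func<double[], double[]> BuildEnsembleGraph(
    Func<double[], double[]>[] teacherGraphs, double[] weights)
{
    return input =>
    {
        double[]? sum = null;
        for (int i = 0; i < teacherGraphs.Length; i++)
        {
            // Path i: run teacher i, then scale its output by w_i.
            double[] scaled = teacherGraphs[i](input).Select(v => weights[i] * v).ToArray();
            // Weighted addition: accumulate element-wise into the running sum.
            sum = sum is null ? scaled : sum.Zip(scaled, (a, b) => a + b).ToArray();
        }
        return sum ?? Array.Empty<double>();
    };
}

// Two toy "teachers": one doubles its input, one adds 1.
var graph = BuildEnsembleGraph(
    new Func<double[], double[]>[]
    {
        x => x.Select(v => 2 * v).ToArray(),
        x => x.Select(v => v + 1).ToArray(),
    },
    new[] { 0.75, 0.25 });
Console.WriteLine(string.Join(", ", graph(new[] { 1.0, 2.0 })));
```

Because the combination is a fixed sequence of multiplies and adds with no data-dependent branching, it can be traced once and compiled, which is consistent with JIT being limited to WeightedAverage mode.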

Exceptions

NotSupportedException

Thrown when the aggregation mode is not WeightedAverage or when any teacher does not support JIT.

GetLogits(Vector<T>)

Gets ensemble logits by combining predictions from all teachers.

public override Vector<T> GetLogits(Vector<T> input)

Parameters

input Vector<T>

The input vector, passed to every teacher in the ensemble.

Returns

Vector<T>

Ensemble logits.

Remarks

For Beginners: This combines logits from all teachers according to the aggregation mode (usually weighted average).

Architecture Note: Returns raw ensemble logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.
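The split of responsibilities described in the note can be sketched as two separate steps: the teacher returns a weighted average of raw logits, and a distillation strategy later applies temperature scaling and softmax. Both helper names are hypothetical, and plain `double[]` stands in for `Vector<T>`.

```csharp
using System;
using System.Linq;

// Teacher side: combine raw logits only (weighted average, no softmax).
static double[] EnsembleLogits(double[][] teacherLogits, double[] weights)
{
    var result = new double[teacherLogits[0].Length];
    for (int i = 0; i < teacherLogits.Length; i++)
        for (int j = 0; j < result.Length; j++)
            result[j] += weights[i] * teacherLogits[i][j];
    return result;
}

// Strategy side: temperature-scaled softmax applied to the raw ensemble logits.
static double[] SoftmaxWithTemperature(double[] logits, double temperature)
{
    double[] scaled = logits.Select(z => z / temperature).ToArray();
    double max = scaled.Max(); // subtract the max for numerical stability
    double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
    double total = exps.Sum();
    return exps.Select(e => e / total).ToArray();
}

var ensembleLogits = EnsembleLogits(
    new[] { new[] { 2.0, 0.0, -1.0 }, new[] { 1.0, 1.0, 0.0 } },
    new[] { 0.5, 0.5 });
var softTargets = SoftmaxWithTemperature(ensembleLogits, temperature: 4.0);
Console.WriteLine(string.Join(", ", softTargets.Select(p => p.ToString("F3"))));
```

Keeping softmax out of the teacher lets each distillation strategy choose its own temperature without the teacher having to know about it.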

UpdateWeights(Vector<T>[], Vector<T>[])

Updates teacher weights based on performance (for adaptive weighting).

public void UpdateWeights(Vector<T>[] validationInputs, Vector<T>[] validationLabels)

Parameters

validationInputs Vector<T>[]

Validation inputs for evaluation.

validationLabels Vector<T>[]

Validation labels, one per validation input.

Remarks

For Advanced Users: Call this periodically to adjust weights based on each teacher's current performance. Better teachers get higher weights.
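The exact rule `UpdateWeights` uses is not documented here. The following is a minimal sketch of one plausible adaptive scheme under stated assumptions: labels are one-hot vectors, a teacher's score is its arg-max accuracy on the validation set, and scores are normalized to sum to 1.0. Teachers are modeled as `Func<double[], double[]>`, and `ArgMax` and `AdaptiveWeights` are hypothetical helpers.

```csharp
using System;
using System.Linq;

static int ArgMax(double[] v)
{
    int best = 0;
    for (int j = 1; j < v.Length; j++)
        if (v[j] > v[best]) best = j;
    return best;
}

// Assumed scheme: accuracy = fraction of inputs where the teacher's arg-max
// matches the one-hot label's arg-max; weights = accuracy / sum(accuracy).
double[] AdaptiveWeights(
    Func<double[], double[]>[] teachers, double[][] validationInputs, double[][] validationLabels)
{
    var accuracy = new double[teachers.Length];
    for (int t = 0; t < teachers.Length; t++)
    {
        int correct = 0;
        for (int n = 0; n < validationInputs.Length; n++)
            if (ArgMax(teachers[t](validationInputs[n])) == ArgMax(validationLabels[n]))
                correct++;
        accuracy[t] = (double)correct / validationInputs.Length;
    }
    double total = accuracy.Sum();
    // Fall back to uniform weights if no teacher classified anything correctly.
    return total > 0
        ? accuracy.Select(a => a / total).ToArray()
        : Enumerable.Repeat(1.0 / teachers.Length, teachers.Length).ToArray();
}

// Toy data: teacher A always predicts class 0; teacher B predicts by input sign.
var inputs = new[] { new[] { 1.0 }, new[] { -1.0 }, new[] { 2.0 }, new[] { -3.0 } };
var labels = new[] { new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 }, new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 } };
Func<double[], double[]> teacherA = x => new[] { 1.0, 0.0 };
Func<double[], double[]> teacherB = x => x[0] >= 0 ? new[] { 1.0, 0.0 } : new[] { 0.0, 1.0 };
var updated = AdaptiveWeights(new[] { teacherA, teacherB }, inputs, labels);
Console.WriteLine(string.Join(", ", updated.Select(x => x.ToString("F3"))));
```

Teacher B is right on all four inputs and teacher A on two, so B ends up with twice A's weight. Calling such an update on a held-out set rather than the student's training data avoids rewarding teachers that merely memorized that data.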