
Class DistributedTeacherModel<T>

Namespace
AiDotNet.KnowledgeDistillation.Teachers
Assembly
AiDotNet.dll

Distributed teacher model that aggregates predictions from multiple distributed workers.

public class DistributedTeacherModel<T> : TeacherModelBase<Vector<T>, Vector<T>, T>, ITeacherModel<Vector<T>, Vector<T>>, IJitCompilable<T>

Type Parameters

T

The numeric element type of the input and output vectors.

Inheritance
TeacherModelBase<Vector<T>, Vector<T>, T>
DistributedTeacherModel<T>
Implements
ITeacherModel<Vector<T>, Vector<T>>
IJitCompilable<T>

Constructors

DistributedTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[], AggregationMode)

public DistributedTeacherModel(ITeacherModel<Vector<T>, Vector<T>>[] workers, AggregationMode aggregation = AggregationMode.Average)

Parameters

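workers ITeacherModel<Vector<T>, Vector<T>>[]

The worker teacher models whose predictions are aggregated.

aggregation AggregationMode

How worker outputs are combined. Defaults to AggregationMode.Average.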

Properties

OutputDimension

Gets the number of output dimensions (e.g., number of classes for classification).

public override int OutputDimension { get; }

Property Value

int

Remarks

For Implementers: Return the size of the output vector. For example, a 10-class classifier should return 10.

SupportsJitCompilation

Gets whether this teacher supports JIT compilation.

public override bool SupportsJitCompilation { get; }

Property Value

bool

Returns true if Average aggregation mode is used and all workers support JIT compilation; otherwise, false.

Remarks

Note: Although "distributed" suggests workers running on different machines, JIT compilation is supported only when all workers are local models that implement IJitCompilable. In that case, their computation graphs can be combined into a single graph for optimized inference.
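The support rule described above can be sketched in a few lines of C#. This is a minimal illustration, not the AiDotNet implementation: the helper names `SupportsJit` and `EnsureJitExportable` are hypothetical, and each worker's JIT capability is modeled as a plain `bool` rather than an `IJitCompilable<T>` instance. The second helper mirrors the `NotSupportedException` documented for `ExportComputationGraph`.

```csharp
using System;
using System.Linq;

// Hypothetical helper mirroring the documented rule: JIT compilation is
// supported only when Average aggregation is used AND every worker supports it.
// Each bool stands in for one worker's SupportsJitCompilation value (assumption).
bool SupportsJit(bool usesAverageAggregation, bool[] workerSupportsJit) =>
    usesAverageAggregation && workerSupportsJit.All(supported => supported);

// Hypothetical guard mirroring the documented NotSupportedException
// thrown by ExportComputationGraph when the rule is not satisfied.
void EnsureJitExportable(bool usesAverageAggregation, bool[] workerSupportsJit)
{
    if (!SupportsJit(usesAverageAggregation, workerSupportsJit))
        throw new NotSupportedException(
            "JIT export requires Average aggregation and JIT-capable workers.");
}

Console.WriteLine(SupportsJit(true, new[] { true, true }));   // True
Console.WriteLine(SupportsJit(true, new[] { true, false }));  // False: one worker lacks JIT
Console.WriteLine(SupportsJit(false, new[] { true, true }));  // False: not Average mode

try
{
    EnsureJitExportable(false, new[] { true });
}
catch (NotSupportedException ex)
{
    Console.WriteLine(ex.Message);
}
```

Because the check requires every worker to qualify, a single non-JIT worker disables JIT compilation for the whole distributed teacher.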

Methods

ExportComputationGraph(List<ComputationNode<T>>)

Exports the distributed worker computation graph for JIT compilation.

public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the averaged worker output.

Remarks

The distributed graph combines the workers' computation graphs by averaging their outputs:

output = (worker1_output + worker2_output + ... + workerN_output) / N

where N is the number of workers.

Note: JIT compilation creates a single optimized computation graph combining all worker models. This is beneficial when workers are local models; for truly distributed inference across machines, use runtime aggregation instead.
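The averaging formula can be checked numerically with the sketch below. It assumes `double[]` arrays in place of `Vector<T>` and uses a hypothetical helper, `AverageOutputs`, that evaluates the sum-and-divide eagerly instead of building computation nodes. It illustrates the arithmetic the exported graph represents, not how AiDotNet constructs that graph.

```csharp
using System;
using System.Linq;

// Element-wise average of N worker output vectors:
// output[i] = (worker1[i] + worker2[i] + ... + workerN[i]) / N
// double[] stands in for Vector<T> (assumption for this sketch).
double[] AverageOutputs(double[][] workerOutputs)
{
    if (workerOutputs.Length == 0)
        throw new ArgumentException("At least one worker output is required.");

    int dim = workerOutputs[0].Length;
    var result = new double[dim];

    // Sum the outputs element by element.
    foreach (var output in workerOutputs)
    {
        if (output.Length != dim)
            throw new ArgumentException("All worker outputs must have the same dimension.");
        for (int i = 0; i < dim; i++)
            result[i] += output[i];
    }

    // Divide by the number of workers N.
    for (int i = 0; i < dim; i++)
        result[i] /= workerOutputs.Length;

    return result;
}

var averaged = AverageOutputs(new[]
{
    new[] { 1.0, 2.0, 3.0 },
    new[] { 3.0, 4.0, 5.0 },
});
Console.WriteLine(string.Join(", ", averaged)); // 2, 3, 4
```

Rejecting mismatched lengths reflects the assumption that every worker produces vectors of the same OutputDimension.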

Exceptions

NotSupportedException

Thrown when the aggregation mode is not Average or when any worker does not support JIT compilation.

GetLogits(Vector<T>)

Gets aggregated logits from all distributed workers.

public override Vector<T> GetLogits(Vector<T> input)

Parameters

input Vector<T>

The input vector to evaluate with every worker.

Returns

Vector<T>

The workers' logits, combined according to the configured AggregationMode.

Remarks

Architecture Note: Returns raw aggregated logits. Temperature scaling and softmax are handled by distillation strategies, not by the teacher model.
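To illustrate this division of responsibility, the sketch below shows a temperature-scaled softmax of the kind a distillation strategy might apply to the teacher's raw logits. It is a hedged example under stated assumptions: `SoftmaxWithTemperature` is a hypothetical helper, not an AiDotNet API, and `double[]` stands in for `Vector<T>`.

```csharp
using System;
using System.Linq;

// Hypothetical strategy-side helper (not an AiDotNet API): the teacher
// returns raw logits, and the strategy applies softmax(logits / temperature).
double[] SoftmaxWithTemperature(double[] logits, double temperature)
{
    if (temperature <= 0)
        throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

    // Scale by temperature, then subtract the max for numerical stability.
    var scaled = logits.Select(z => z / temperature).ToArray();
    double max = scaled.Max();
    var exps = scaled.Select(s => Math.Exp(s - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

// Example raw logits, standing in for the output of GetLogits.
double[] rawLogits = { 2.0, 1.0, 0.1 };
var hard = SoftmaxWithTemperature(rawLogits, 1.0);
var soft = SoftmaxWithTemperature(rawLogits, 4.0);

Console.WriteLine(string.Join(", ", hard.Select(p => p.ToString("F3"))));
Console.WriteLine(string.Join(", ", soft.Select(p => p.ToString("F3"))));
```

Dividing the logits by a temperature greater than 1 yields a flatter distribution. Keeping this step outside the teacher lets each strategy choose its own temperature.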