
Class SimilarityPreservingStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Distillation strategy that trains the student to preserve the teacher's pairwise similarity structure, that is, how similar the samples in a batch are to one another.

public class SimilarityPreservingStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
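For orientation, the similarity-preserving loss as published by Tung and Mori (ICCV 2019) is written out below. This page does not state AiDotNet's exact normalization, so read it as the standard formulation rather than necessarily this class's implementation. For a batch of b samples, stack the student's embeddings as the rows of A_S and the teacher's as the rows of A_T. The two may have different widths, because only the b × b similarity matrices are compared.

```latex
\begin{aligned}
\tilde{G}_S &= A_S A_S^{\top}, \qquad \tilde{G}_T = A_T A_T^{\top} \in \mathbb{R}^{b \times b},\\
G_S[i,:] &= \frac{\tilde{G}_S[i,:]}{\lVert \tilde{G}_S[i,:] \rVert_2}, \qquad
G_T[i,:] = \frac{\tilde{G}_T[i,:]}{\lVert \tilde{G}_T[i,:] \rVert_2},\\
\mathcal{L}_{\mathrm{SP}} &= \frac{1}{b^{2}} \, \lVert G_T - G_S \rVert_F^{2}.
\end{aligned}
```

Row normalization makes the loss invariant to the overall scale of each model's embeddings.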

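Type Parameters

T

The numeric element type of the matrices and vectors the strategy operates on.

Inheritance

DistillationStrategyBase<T>
SimilarityPreservingStrategy<T>

Implements

IDistillationStrategy<T>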

Constructors

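SimilarityPreservingStrategy(double, double, double)

Initializes a new instance of the SimilarityPreservingStrategy<T> class.

public SimilarityPreservingStrategy(double similarityWeight = 0.5, double temperature = 3, double alpha = 0.3)

Parameters

similarityWeight double

Weight of the similarity-preservation term in the overall loss.

temperature double

Softmax temperature used to soften the teacher and student logits; handled by the base class.

alpha double

Loss-balancing coefficient handled by the base class. In standard distillation, alpha balances the hard-label loss against the soft teacher loss.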
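A minimal sketch of what the temperature parameter conventionally does, assuming this strategy follows standard Hinton-style softening: logits are divided by the temperature before the softmax, so higher values give flatter distributions. TemperatureSoftmax is a hypothetical helper written for illustration with plain double arrays; it is not part of AiDotNet.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only (not an AiDotNet type).
public static class TemperatureSoftmax
{
    // Returns softmax(logits / temperature).
    public static double[] Softmax(double[] logits, double temperature)
    {
        if (temperature <= 0)
            throw new ArgumentOutOfRangeException(nameof(temperature), "Temperature must be positive.");

        // Subtract the largest scaled logit before exponentiating, for numerical stability.
        double max = logits.Max() / temperature;
        double[] exps = logits.Select(z => Math.Exp(z / temperature - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

For example, the logits [2, 1, 0] put about 0.67 of the probability on the top class at temperature 1, but only about 0.45 at the default temperature of 3, so the softened teacher distribution carries more information about the non-top classes.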

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the distillation loss for backpropagation.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

Matrix<T>

The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]

Remarks

For Implementers: Override this method to compute gradients for your strategy. The gradient should be the derivative of the loss computed by ComputeLoss, so that the two methods stay consistent.

Batch Processing: Returns a gradient matrix with the same shape as studentBatchOutput, one gradient row per sample in the batch.
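As an illustration of a gradient that matches its loss, the sketch below computes the gradient of the standard temperature-scaled soft-target loss, T^2 * (1/B) * sum_i KL(p_teacher_i || p_student_i) with p = softmax(logits / T), with respect to the student logits. That gradient is (T / B) * (p_student - p_teacher). It covers only this soft-target term: this class's similarity term and any alpha weighting are not shown. SoftTargetGradient is a hypothetical name, and double[][] stands in for Matrix<T>.

```csharp
using System;
using System.Linq;

// Hypothetical sketch (not AiDotNet's implementation) of a soft-target gradient.
public static class SoftTargetGradient
{
    static double[] Softmax(double[] logits, double temperature)
    {
        double max = logits.Max() / temperature; // stabilize the exponentials
        double[] exps = logits.Select(z => Math.Exp(z / temperature - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Gradient of T^2 * mean_i KL(p_teacher_i || p_student_i) with respect to the
    // student logits. The result has one row per sample, like the input.
    public static double[][] Compute(double[][] student, double[][] teacher, double temperature)
    {
        if (student.Length != teacher.Length)
            throw new ArgumentException("Student and teacher batches must have the same size.");

        int batch = student.Length;
        var grad = new double[batch][];
        for (int i = 0; i < batch; i++)
        {
            double[] ps = Softmax(student[i], temperature);
            double[] pt = Softmax(teacher[i], temperature);
            grad[i] = new double[ps.Length];
            for (int j = 0; j < ps.Length; j++)
                grad[i][j] = temperature * (ps[j] - pt[j]) / batch;
        }
        return grad;
    }
}
```

Each row of the result sums to zero, because both softmax rows sum to one, and the gradient vanishes when the student and teacher logits agree.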

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the distillation loss between student and teacher batch outputs.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's output logits for a batch. Shape: [batch_size x num_classes]

teacherBatchOutput Matrix<T>

The teacher model's output logits for a batch. Shape: [batch_size x num_classes]

trueLabelsBatch Matrix<T>

Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]

Returns

T

The computed distillation loss value (scalar) for the batch.

Remarks

For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha, so the override only needs to implement the loss calculation itself.

Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
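A sketch of a batch loss computed row by row and then averaged, using the standard temperature-scaled soft-target term T^2 * KL(p_teacher || p_student) as the per-sample loss. It assumes plain double[][] inputs in place of Matrix<T> and omits both the hard-label term that trueLabelsBatch would feed and this class's similarity term. SoftTargetLoss is a hypothetical name, not AiDotNet's implementation.

```csharp
using System;
using System.Linq;

// Hypothetical sketch (not AiDotNet's implementation) of a batch-averaged soft-target loss.
public static class SoftTargetLoss
{
    static double[] Softmax(double[] logits, double temperature)
    {
        double max = logits.Max() / temperature; // stabilize the exponentials
        double[] exps = logits.Select(z => Math.Exp(z / temperature - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // T^2 * (1/B) * sum_i KL(p_teacher_i || p_student_i), where p = softmax(logits / T).
    // Each row of the inputs is one sample.
    public static double Compute(double[][] student, double[][] teacher, double temperature)
    {
        if (student.Length != teacher.Length)
            throw new ArgumentException("Student and teacher batches must have the same size.");

        double total = 0;
        for (int i = 0; i < student.Length; i++)
        {
            double[] ps = Softmax(student[i], temperature);
            double[] pt = Softmax(teacher[i], temperature);
            for (int j = 0; j < ps.Length; j++)
                if (pt[j] > 0) // 0 * log(0 / q) contributes nothing to the KL divergence
                    total += pt[j] * Math.Log(pt[j] / ps[j]);
        }
        return temperature * temperature * total / student.Length;
    }
}
```

Multiplying by T^2 is the usual convention that keeps the soft-target gradients on a comparable scale as the temperature changes.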

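ComputeSimilarityLoss(Vector<T>[], Vector<T>[])

Computes the similarity-preservation loss, which penalizes differences between the pairwise similarity structure of the student embeddings and that of the teacher embeddings.

public T ComputeSimilarityLoss(Vector<T>[] studentEmbeddings, Vector<T>[] teacherEmbeddings)

Parameters

studentEmbeddings Vector<T>[]

The student's embeddings, one vector per sample.

teacherEmbeddings Vector<T>[]

The teacher's embeddings, one vector per sample, in the same order as studentEmbeddings.

Returns

T

The similarity loss (scalar).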
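A minimal sketch of a similarity-preserving loss over per-sample embeddings, following the Tung and Mori (2019) formulation: build each model's b × b Gram matrix, L2-normalize its rows, and take the mean squared difference. Whether ComputeSimilarityLoss uses this exact normalization is an assumption. SimilarityPreservingLossSketch is a hypothetical name, and double[][] stands in for Vector<T>[].

```csharp
using System;

// Hypothetical sketch of the Tung & Mori similarity-preserving loss (not AiDotNet's code).
public static class SimilarityPreservingLossSketch
{
    // Builds the b x b Gram matrix G = A A^T and L2-normalizes each row.
    static double[,] NormalizedGram(double[][] embeddings)
    {
        int b = embeddings.Length;
        var g = new double[b, b];
        for (int i = 0; i < b; i++)
            for (int j = 0; j < b; j++)
            {
                double dot = 0;
                for (int k = 0; k < embeddings[i].Length; k++)
                    dot += embeddings[i][k] * embeddings[j][k];
                g[i, j] = dot;
            }

        for (int i = 0; i < b; i++)
        {
            double norm = 0;
            for (int j = 0; j < b; j++) norm += g[i, j] * g[i, j];
            norm = Math.Sqrt(norm);
            if (norm > 0) // leave all-zero rows (zero embeddings) untouched
                for (int j = 0; j < b; j++) g[i, j] /= norm;
        }
        return g;
    }

    // (1 / b^2) * squared Frobenius norm of G_teacher - G_student.
    public static double Compute(double[][] student, double[][] teacher)
    {
        if (student.Length != teacher.Length)
            throw new ArgumentException("Student and teacher must have the same batch size.");

        int b = student.Length;
        double[,] gs = NormalizedGram(student);
        double[,] gt = NormalizedGram(teacher);
        double sum = 0;
        for (int i = 0; i < b; i++)
            for (int j = 0; j < b; j++)
            {
                double d = gt[i, j] - gs[i, j];
                sum += d * d;
            }
        return sum / ((double)b * b);
    }
}
```

Because only the b × b matrices are compared, student and teacher embeddings may have different dimensions, and scaling all embeddings of one model by a constant leaves the loss unchanged.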