Class SimilarityPreservingStrategy<T>
- Namespace: AiDotNet.KnowledgeDistillation.Strategies
- Assembly: AiDotNet.dll
Knowledge distillation strategy that encourages the student to reproduce the teacher's pairwise similarity structure, that is, how similar each pair of samples in a batch is.
public class SimilarityPreservingStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T
The numeric type used for computations (for example, double or float).
- Inheritance
  - object
  - DistillationStrategyBase<T>
  - SimilarityPreservingStrategy<T>
- Implements
  - IDistillationStrategy<T>
Constructors
SimilarityPreservingStrategy(double, double, double)
public SimilarityPreservingStrategy(double similarityWeight = 0.5, double temperature = 3, double alpha = 0.3)
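Parameters
similarityWeight double
Weight of the similarity-preservation term. Default: 0.5.
temperature double
Softmax temperature used to soften the student and teacher outputs. Default: 3.
alpha double
Weight that balances the hard-label loss against the distillation loss; handled by the base class. Default: 0.3.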
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the distillation loss for backpropagation.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput Matrix<T>
The student model's output logits for a batch. Shape: [batch_size x num_classes]
teacherBatchOutput Matrix<T>
The teacher model's output logits for a batch. Shape: [batch_size x num_classes]
trueLabelsBatch Matrix<T>
Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]
Returns
- Matrix<T>
The gradient of the loss with respect to student outputs. Shape: [batch_size x num_classes]
Remarks
For Implementers: Override this method to compute gradients for your strategy. The gradient should match the loss computation in ComputeLoss.
Batch Processing: Returns a gradient matrix with the same shape as studentBatchOutput, one gradient row for each sample in the batch.
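The remarks above require the gradient to match the loss computed in ComputeLoss. As an illustration of that contract, here is a minimal sketch for the standard temperature-softened KL term (Hinton et al.), using double in place of T and jagged arrays in place of Matrix<T>. The helpers SoftmaxWithTemperature and SoftTargetGradient are hypothetical, and this is not necessarily how SimilarityPreservingStrategy<T> computes its gradient; in particular, the similarity and hard-label terms are omitted.

```csharp
using System;

// Numerically stable softmax of logits / temperature (hypothetical helper).
static double[] SoftmaxWithTemperature(double[] logits, double temperature)
{
    double max = double.NegativeInfinity;
    foreach (double z in logits) max = Math.Max(max, z / temperature);

    var probs = new double[logits.Length];
    double sum = 0.0;
    for (int j = 0; j < logits.Length; j++)
    {
        probs[j] = Math.Exp(logits[j] / temperature - max);
        sum += probs[j];
    }
    for (int j = 0; j < probs.Length; j++) probs[j] /= sum;
    return probs;
}

// Gradient of the batch-averaged soft loss
//   L = (1 / B) * sum_i T^2 * KL(p_teacher_i || p_student_i)
// with respect to the student logits. Each entry is T * (p_student - p_teacher) / B,
// so the result has one gradient row per sample, like studentBatchOutput.
// Assumption: only the soft KL term is modeled; similarity and hard-label terms are omitted.
static double[][] SoftTargetGradient(double[][] studentLogits, double[][] teacherLogits, double temperature)
{
    int batchSize = studentLogits.Length;
    var grad = new double[batchSize][];
    for (int i = 0; i < batchSize; i++)
    {
        double[] pS = SoftmaxWithTemperature(studentLogits[i], temperature);
        double[] pT = SoftmaxWithTemperature(teacherLogits[i], temperature);
        grad[i] = new double[pS.Length];
        for (int j = 0; j < pS.Length; j++)
            grad[i][j] = temperature * (pS[j] - pT[j]) / batchSize;
    }
    return grad;
}
```

Because the gradient is derived from the same loss expression, a central finite-difference check against that loss is a cheap way to confirm the two stay consistent.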
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the distillation loss between student and teacher batch outputs.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput Matrix<T>
The student model's output logits for a batch. Shape: [batch_size x num_classes]
teacherBatchOutput Matrix<T>
The teacher model's output logits for a batch. Shape: [batch_size x num_classes]
trueLabelsBatch Matrix<T>
Ground truth labels for the batch (optional). Shape: [batch_size x num_classes]
Returns
- T
The computed distillation loss value (scalar) for the batch.
Remarks
For Implementers: Override this method to define your strategy's loss computation. The base class handles temperature and alpha; you focus on the loss calculation logic.
Batch Processing: The loss should be computed over all samples in the batch and typically averaged. Each row in the input matrices represents one sample.
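To make the batch averaging and the roles of temperature and alpha concrete, here is a minimal sketch of a conventional distillation loss over a batch, using double for T and one array row per sample. The weighting alpha * hard + (1 - alpha) * soft and the T^2 scaling follow the common Hinton-style formulation; they are assumptions here, not the documented behavior of DistillationStrategyBase<T>, and the similarity term is omitted. Softmax and DistillationLoss are hypothetical names.

```csharp
using System;

// Numerically stable softmax of logits / temperature (hypothetical helper).
static double[] Softmax(double[] logits, double temperature)
{
    double max = double.NegativeInfinity;
    foreach (double z in logits) max = Math.Max(max, z / temperature);

    var probs = new double[logits.Length];
    double sum = 0.0;
    for (int j = 0; j < logits.Length; j++)
    {
        probs[j] = Math.Exp(logits[j] / temperature - max);
        sum += probs[j];
    }
    for (int j = 0; j < probs.Length; j++) probs[j] /= sum;
    return probs;
}

// Batch-averaged distillation loss, one array row per sample (hypothetical helper).
// Assumed weighting, following the common Hinton-style formulation:
//   soft = T^2 * KL(softmax(teacher / T) || softmax(student / T))
//   hard = cross-entropy(softmax(student), trueLabels)
//   per-sample loss = alpha * hard + (1 - alpha) * soft   (soft only when labels is null)
static double DistillationLoss(double[][] student, double[][] teacher, double[][]? labels,
                               double temperature, double alpha)
{
    double total = 0.0;
    for (int i = 0; i < student.Length; i++)
    {
        double[] pS = Softmax(student[i], temperature);
        double[] pT = Softmax(teacher[i], temperature);
        double soft = 0.0;
        for (int j = 0; j < pS.Length; j++)
            if (pT[j] > 0) soft += pT[j] * Math.Log(pT[j] / pS[j]);
        soft *= temperature * temperature;

        if (labels is null)
        {
            total += soft;
            continue;
        }

        double[] pHard = Softmax(student[i], 1.0); // hard-label loss uses T = 1
        double hard = 0.0;
        for (int j = 0; j < pHard.Length; j++)
            if (labels[i][j] > 0) hard -= labels[i][j] * Math.Log(pHard[j]);
        total += alpha * hard + (1 - alpha) * soft;
    }
    return total / student.Length; // average over the samples (rows) in the batch
}
```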
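ComputeSimilarityLoss(Vector<T>[], Vector<T>[])
Computes the similarity-preservation loss between the student's and the teacher's embeddings.
public T ComputeSimilarityLoss(Vector<T>[] studentEmbeddings, Vector<T>[] teacherEmbeddings)
Parameters
studentEmbeddings Vector<T>[]
The student's embeddings, one vector per sample.
teacherEmbeddings Vector<T>[]
The teacher's embeddings, one vector per sample, in the same order as studentEmbeddings.
Returns
- T
The similarity-preservation loss (scalar).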
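This page does not document the formula behind ComputeSimilarityLoss. A minimal sketch of the similarity-preserving loss from Tung and Mori (2019), assuming one embedding per sample, is shown below: each batch's b x b Gram matrix is L2-normalized row by row, and the squared Frobenius distance between the student and teacher matrices is averaged over the b^2 entries. Whether AiDotNet uses this exact normalization is an assumption; NormalizedGram and SimilarityLoss are hypothetical helpers working on double arrays.

```csharp
using System;

// Hypothetical helper: builds the b x b Gram matrix G = A * A^T of a batch of
// embeddings (one row per sample) and L2-normalizes each row, as in Tung and Mori (2019).
static double[,] NormalizedGram(double[][] embeddings)
{
    int b = embeddings.Length;
    var g = new double[b, b];
    for (int i = 0; i < b; i++)
    {
        for (int k = 0; k < b; k++)
        {
            double dot = 0.0;
            for (int d = 0; d < embeddings[i].Length; d++)
                dot += embeddings[i][d] * embeddings[k][d];
            g[i, k] = dot;
        }
        double norm = 0.0;
        for (int k = 0; k < b; k++) norm += g[i, k] * g[i, k];
        norm = Math.Sqrt(norm);
        if (norm > 0)
            for (int k = 0; k < b; k++) g[i, k] /= norm;
    }
    return g;
}

// Similarity-preserving loss (hypothetical helper):
//   (1 / b^2) * ||G_teacher - G_student||_F^2 over the row-normalized Gram matrices.
static double SimilarityLoss(double[][] student, double[][] teacher)
{
    if (student.Length != teacher.Length)
        throw new ArgumentException("Student and teacher batches must contain the same number of samples.");
    int b = student.Length;
    double[,] gS = NormalizedGram(student);
    double[,] gT = NormalizedGram(teacher);
    double sum = 0.0;
    for (int i = 0; i < b; i++)
        for (int k = 0; k < b; k++)
        {
            double diff = gT[i, k] - gS[i, k];
            sum += diff * diff;
        }
    return sum / (b * (double)b);
}
```

Because only the b x b similarity matrices are compared, the student and teacher embeddings may have different dimensions, as long as both batches contain the same samples in the same order.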