Class ContrastiveDistillationStrategy<T>
Namespace: AiDotNet.KnowledgeDistillation.Strategies
Assembly: AiDotNet.dll
Implements Contrastive Representation Distillation (CRD), which transfers knowledge through contrastive learning of sample relationships rather than just matching outputs.
public class ContrastiveDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IIntermediateActivationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
Inheritance: object → DistillationStrategyBase<T> → ContrastiveDistillationStrategy<T>
Implements: IDistillationStrategy<T>, IIntermediateActivationStrategy<T>
Remarks
For Beginners: Contrastive distillation teaches the student to understand which samples are similar and which are different, not just to copy the teacher's predictions. It's like learning to group things by their similarities rather than just memorizing labels.
Real-world Analogy: Instead of just teaching a student "This is a dog," you teach them "Dogs are more similar to wolves than to cats" and "Retrievers are more similar to Labs than to Chihuahuas." This relational understanding helps the student generalize better to new examples.
How CRD Works:
1. Extract embeddings/features from both the teacher and the student.
2. For each sample (the anchor), identify:
   - Positive samples: same class or similar features.
   - Negative samples: different class or dissimilar features.
3. Pull the student embeddings of the anchor and its positives together.
4. Push the student embeddings of the anchor and its negatives apart.
5. Ensure the student's embedding space has the same structure as the teacher's.
Key Differences from Standard Distillation:
- **Standard**: Match output probabilities, e.g., [0.1, 0.7, 0.2].
- **CRD**: Match embedding similarities and distances.
- **Benefit**: Better generalization, especially for few-shot learning.
Mathematical Foundation: CRD uses the InfoNCE loss (Noise Contrastive Estimation):
L = -log( exp(sim(t_i, s_i)/τ) / Σ_j exp(sim(t_i, s_j)/τ) )
where:
- t_i, s_i are the teacher and student embeddings of sample i
- τ is the temperature
- j ranges over all samples in the batch
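As a concrete illustration of this loss, here is a minimal, self-contained C# sketch (not the library's internal implementation; the class and method names are hypothetical). Each student embedding's positive is the teacher embedding of the same sample, and every other teacher embedding in the batch serves as a negative:

using System;

static class InfoNceSketch
{
    // Average InfoNCE / NT-Xent loss over a batch of paired embeddings.
    public static double Loss(double[][] student, double[][] teacher, double tau = 0.07)
    {
        int n = student.Length;
        double total = 0.0;
        for (int i = 0; i < n; i++)
        {
            double positive = 0.0, denominator = 0.0;
            for (int j = 0; j < n; j++)
            {
                double e = Math.Exp(Cosine(student[i], teacher[j]) / tau);
                denominator += e;
                if (j == i) positive = e;   // the matching pair (i, i) is the positive
            }
            total += -Math.Log(positive / denominator);
        }
        return total / n;
    }

    // Cosine similarity sim(a, b) = a·b / (||a|| ||b||).
    public static double Cosine(double[] a, double[] b)
    {
        double dot = 0, na = 0, nb = 0;
        for (int k = 0; k < a.Length; k++) { dot += a[k] * b[k]; na += a[k] * a[k]; nb += b[k] * b[k]; }
        return dot / (Math.Sqrt(na) * Math.Sqrt(nb) + 1e-12);
    }
}

Lower temperatures sharpen the softmax over similarities, which is why values around 0.07 are common in contrastive learning.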
Benefits:
- **Better Features**: The student learns richer representations.
- **Few-Shot Learning**: Transfers better to new classes.
- **Robustness**: Less sensitive to label noise.
- **Interpretability**: The embedding space is more structured.
- **Complementary**: Can be combined with standard distillation.
Use Cases:
- Few-shot/zero-shot learning
- Transfer learning across domains
- Learning with noisy labels
- Metric learning tasks (face recognition, image retrieval)
- Self-supervised pre-training
Performance Improvements:
- CRD often gives 2-4% better accuracy than standard distillation.
- It is particularly strong for small student models.
- It is excellent for tasks that require good embeddings.
References:
- Tian et al. (2020). Contrastive Representation Distillation. ICLR.
- Chen et al. (2020). A Simple Framework for Contrastive Learning of Visual Representations. ICML.
Constructors
ContrastiveDistillationStrategy(string, double, double, double, int, ContrastiveMode)
Initializes a new instance of the ContrastiveDistillationStrategy class.
public ContrastiveDistillationStrategy(string embeddingLayerName = "embeddings", double contrastiveWeight = 0.8, double temperature = 0.07, double alpha = 0.2, int negativesSampleSize = 1024, ContrastiveMode mode = ContrastiveMode.NTXent)
Parameters
embeddingLayerName (string): Name of the layer to extract embeddings from (e.g., "layer_before_classifier").
contrastiveWeight (double): Weight for the contrastive loss vs. the standard output loss (default: 0.8).
temperature (double): Temperature for the contrastive softmax (default: 0.07).
alpha (double): Balance between hard and soft loss (default: 0.2).
negativesSampleSize (int): Number of negative samples to use (default: 1024).
mode (ContrastiveMode): Contrastive mode (default: NTXent for label-free operation).
Remarks
For Beginners: Configure how much to weight contrastive learning:
- contrastiveWeight 0.6-0.9: more focus on learning representations.
- temperature 0.05-0.1: lower values give sharper distinctions between similar and dissimilar samples.
- negativesSampleSize 512-2048: more negatives give better discrimination.
Example:
var strategy = new ContrastiveDistillationStrategy<double>(
embeddingLayerName: "embedding_layer",
contrastiveWeight: 0.8, // 80% contrastive, 20% standard
temperature: 0.07, // Standard for contrastive learning
alpha: 0.2, // Mostly teacher knowledge
negativesSampleSize: 1024 // Large negative set for better discrimination
);
Methods
ComputeContrastiveLoss(Vector<T>[], Vector<T>[], int[])
Computes contrastive loss on embeddings/features.
public T ComputeContrastiveLoss(Vector<T>[] studentEmbeddings, Vector<T>[] teacherEmbeddings, int[] labels)
Parameters
studentEmbeddings (Vector<T>[]): Student embeddings for the batch.
teacherEmbeddings (Vector<T>[]): Teacher embeddings for the batch.
labels (int[]): Sample labels used to determine positives and negatives.
Returns
- T
Contrastive loss value.
Remarks
For Beginners: This measures how well the student's embedding space matches the teacher's structural relationships. Lower loss means better match.
The embeddings should come from intermediate layers (not the final outputs), since intermediate layers carry richer representational information.
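A hedged usage sketch follows; the Vector<double>(double[]) constructor and the way per-sample embeddings are extracted from the networks are assumptions for illustration, not taken from this page:

var strategy = new ContrastiveDistillationStrategy<double>(embeddingLayerName: "embedding_layer");

// Hypothetical embeddings for a batch of three samples, normally extracted from the
// teacher's and student's embedding layer; Vector<double>(double[]) is assumed here.
var studentEmbeddings = new[]
{
    new Vector<double>(new[] { 0.1, 0.9 }),
    new Vector<double>(new[] { 0.8, 0.2 }),
    new Vector<double>(new[] { 0.2, 0.8 })
};
var teacherEmbeddings = new[]
{
    new Vector<double>(new[] { 0.0, 1.0 }),
    new Vector<double>(new[] { 0.9, 0.1 }),
    new Vector<double>(new[] { 0.1, 0.9 })
};
int[] labels = { 0, 1, 0 };   // samples 0 and 2 share a class, so they act as positives for each other

double loss = strategy.ComputeContrastiveLoss(studentEmbeddings, teacherEmbeddings, labels);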
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the standard output loss.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student's batch output.
teacherBatchOutput (Matrix<T>): The teacher's batch output.
trueLabelsBatch (Matrix<T>?): Optional true labels for the batch.
Returns
- Matrix<T>
ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)
Computes gradients of intermediate activation loss with respect to student embeddings.
public IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
studentIntermediateActivations (IntermediateActivations<T>): The student's intermediate layer activations (must include the embedding layer).
teacherIntermediateActivations (IntermediateActivations<T>): The teacher's intermediate layer activations (must include the embedding layer).
Returns
- IntermediateActivations<T>
Gradients for the embedding layer (already weighted by contrastiveWeight).
Remarks
Computes analytical gradients for NT-Xent loss with respect to student embeddings. Uses cosine similarity gradients derived from the quotient rule.
For NT-Xent loss L_i = -sim(s_i, t_i)/τ + log(Σ_j exp(sim(s_i, t_j)/τ)), the gradient involves both positive and negative pair contributions.
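As a sketch of that derivation (not the library's source; it reuses the hypothetical Cosine helper from the earlier sketch), writing p_j for the softmax weight exp(sim(s_i, t_j)/τ) / Σ_k exp(sim(s_i, t_k)/τ), the gradient for anchor i reduces to Σ_j (p_j - 1[j = i]) / τ · ∂sim(s_i, t_j)/∂s_i:

// Gradient of the NT-Xent loss for anchor i with respect to the student embedding s.
public static double[] NtXentGradient(double[] s, double[][] teachers, int i, double tau = 0.07)
{
    int n = teachers.Length, d = s.Length;
    var expSims = new double[n];
    double denom = 0.0;
    for (int j = 0; j < n; j++)
    {
        expSims[j] = Math.Exp(InfoNceSketch.Cosine(s, teachers[j]) / tau);
        denom += expSims[j];
    }

    var grad = new double[d];
    for (int j = 0; j < n; j++)
    {
        double coeff = (expSims[j] / denom - (j == i ? 1.0 : 0.0)) / tau;   // (p_j - 1[j == i]) / tau
        double[] dSim = CosineGradWrtFirst(s, teachers[j]);
        for (int k = 0; k < d; k++) grad[k] += coeff * dSim[k];
    }
    return grad;
}

// Quotient-rule gradient of cosine similarity with respect to its first argument:
// d sim(s, t)/ds = t / (||s|| ||t||) - (s·t) s / (||s||^3 ||t||).
public static double[] CosineGradWrtFirst(double[] s, double[] t)
{
    double dot = 0, ns = 0, nt = 0;
    for (int k = 0; k < s.Length; k++) { dot += s[k] * t[k]; ns += s[k] * s[k]; nt += t[k] * t[k]; }
    ns = Math.Sqrt(ns); nt = Math.Sqrt(nt);
    var g = new double[s.Length];
    for (int k = 0; k < s.Length; k++)
        g[k] = t[k] / (ns * nt) - dot * s[k] / (ns * ns * ns * nt);
    return g;
}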
ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)
Computes intermediate activation loss by matching embedding space structure between teacher and student.
public T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
studentIntermediateActivations (IntermediateActivations<T>): The student's intermediate layer activations (must include the embedding layer).
teacherIntermediateActivations (IntermediateActivations<T>): The teacher's intermediate layer activations (must include the embedding layer).
Returns
- T
The contrastive loss (already weighted by contrastiveWeight).
Remarks
This implements the IIntermediateActivationStrategy interface to properly integrate contrastive learning into the training loop. The loss is computed from embeddings stored in the intermediate activations for the layer specified in the constructor.
Uses NTXent mode by default as it doesn't require labels. Each teacher-student pair for the same sample is treated as a positive pair, while other samples are negatives.
If the embedding layer is not found, returns zero loss.
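The sketch below shows, under loud assumptions, how this fits into one training step; how the batch outputs and the IntermediateActivations<T> objects are captured depends on your model setup and is not covered on this page, and the simple sum of the two loss terms is only illustrative:

// One distillation step, assuming the surrounding loop already produced:
//   studentBatchOutput, teacherBatchOutput, trueLabelsBatch  (Matrix<double>)
//   studentActivations, teacherActivations                   (IntermediateActivations<double>, including the embedding layer)
double outputLoss = strategy.ComputeLoss(studentBatchOutput, teacherBatchOutput, trueLabelsBatch);
double contrastiveLoss = strategy.ComputeIntermediateLoss(studentActivations, teacherActivations);

// ComputeIntermediateLoss already includes the contrastiveWeight factor, so the two
// terms are simply added here for illustration.
double totalLoss = outputLoss + contrastiveLoss;

// For the backward pass, the matching gradients come from ComputeGradient (outputs)
// and ComputeIntermediateGradient (embedding layer).
var outputGradient = strategy.ComputeGradient(studentBatchOutput, teacherBatchOutput, trueLabelsBatch);
var embeddingGradient = strategy.ComputeIntermediateGradient(studentActivations, teacherActivations);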
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the standard output loss (the contrastive loss is computed separately on the embeddings).
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput (Matrix<T>): The student's batch output.
teacherBatchOutput (Matrix<T>): The teacher's batch output.
trueLabelsBatch (Matrix<T>?): Optional true labels for the batch.
Returns
- T
The standard output loss value.