
Class ContrastiveDistillationStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Implements Contrastive Representation Distillation (CRD), which transfers knowledge through contrastive learning of relationships between samples rather than only matching outputs.

public class ContrastiveDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IIntermediateActivationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
object
DistillationStrategyBase<T>
ContrastiveDistillationStrategy<T>
Implements
IDistillationStrategy<T>
IIntermediateActivationStrategy<T>

Remarks

For Beginners: Contrastive distillation teaches the student to understand which samples are similar and which are different, not just to copy the teacher's predictions. It's like learning to group things by their similarities rather than just memorizing labels.

Real-world Analogy: Instead of just teaching a student "This is a dog," you teach them "Dogs are more similar to wolves than to cats" and "Retrievers are more similar to Labs than to Chihuahuas." This relational understanding helps the student generalize better to new examples.

How CRD Works:
1. Extract embeddings/features from teacher and student.
2. For each sample (anchor), identify:
   - Positive samples: same class or similar features
   - Negative samples: different class or dissimilar features
3. Pull the student embeddings of the anchor and its positives together.
4. Push the student embeddings of the anchor and its negatives apart.
5. Ensure the student's embedding space has the same structure as the teacher's.
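Step 2 can be sketched for the label-based case (the labels parameter of ComputeContrastiveLoss). This is a minimal illustration assuming integer class labels; SplitPairs is a hypothetical helper, not an AiDotNet API. In the label-free NT-Xent mode, the only positive for a student anchor is the teacher embedding of the same sample.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of AiDotNet): split the other samples in a batch
// into positives (same label as the anchor) and negatives (different label).
static (List<int> Positives, List<int> Negatives) SplitPairs(int[] labels, int anchor)
{
    var positives = new List<int>();
    var negatives = new List<int>();
    for (int j = 0; j < labels.Length; j++)
    {
        if (j == anchor) continue; // the anchor is not paired with itself
        if (labels[j] == labels[anchor]) positives.Add(j);
        else negatives.Add(j);
    }
    return (positives, negatives);
}

int[] batchLabels = { 0, 1, 0, 2, 1 };
var (pos, neg) = SplitPairs(batchLabels, anchor: 0);
Console.WriteLine($"positives: [{string.Join(", ", pos)}], negatives: [{string.Join(", ", neg)}]");
// prints "positives: [2], negatives: [1, 3, 4]"
```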

Key Differences from Standard Distillation:
- **Standard**: Match output probabilities, e.g. [0.1, 0.7, 0.2]
- **CRD**: Match embedding similarities and distances
- **Benefit**: Better generalization, especially for few-shot learning

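Mathematical Foundation: CRD uses the InfoNCE loss, an objective based on noise-contrastive estimation:

L = -log( exp(sim(t_i, s_i)/τ) / Σ_j exp(sim(t_i, s_j)/τ) )

where:
- t_i, s_i are the teacher and student embeddings of sample i
- sim(·,·) is a similarity function between embeddings (typically cosine similarity)
- τ is the temperature
- j ranges over all samples in the batch, so the terms with j ≠ i act as negatives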
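A minimal C# sketch of the formula above, assuming cosine similarity for sim, plain double[] arrays in place of Vector<T>, and averaging the per-sample loss over the batch (the averaging is an assumption). InfoNce and Cosine are hypothetical helpers, not library functions. The log-sum-exp is shifted by its maximum so that small temperatures such as 0.07 do not overflow.

```csharp
using System;

// Cosine similarity between two embedding vectors (epsilon guards against zero norms).
static double Cosine(double[] a, double[] b)
{
    double dot = 0, na = 0, nb = 0;
    for (int k = 0; k < a.Length; k++)
    {
        dot += a[k] * b[k];
        na += a[k] * a[k];
        nb += b[k] * b[k];
    }
    return dot / (Math.Sqrt(na) * Math.Sqrt(nb) + 1e-12);
}

// InfoNCE with teacher anchors: t_i is the anchor, s_i its positive,
// and s_j (j != i) its negatives. Returns the mean loss over the batch.
static double InfoNce(double[][] teacher, double[][] student, double tau)
{
    int n = teacher.Length;
    double total = 0;
    for (int i = 0; i < n; i++)
    {
        var logits = new double[n];
        double max = double.NegativeInfinity;
        for (int j = 0; j < n; j++)
        {
            logits[j] = Cosine(teacher[i], student[j]) / tau;
            max = Math.Max(max, logits[j]);
        }
        // Stable log-sum-exp: shift by the largest logit before exponentiating.
        double sumExp = 0;
        for (int j = 0; j < n; j++) sumExp += Math.Exp(logits[j] - max);
        double logSumExp = max + Math.Log(sumExp);
        // -log(exp(l_ii) / sum_j exp(l_ij)) = logSumExp - l_ii
        total += logSumExp - logits[i];
    }
    return total / n;
}

double[][] t = { new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 } };
double[][] sAligned = { new[] { 0.9, 0.1 }, new[] { 0.1, 0.9 } };
double[][] sSwapped = { new[] { 0.1, 0.9 }, new[] { 0.9, 0.1 } };
Console.WriteLine($"aligned: {InfoNce(t, sAligned, 0.07):F4}, swapped: {InfoNce(t, sSwapped, 0.07):F4}");
```

With aligned embeddings the loss is close to zero; swapping the student embeddings makes each anchor's positive the least similar candidate, and the loss rises sharply.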

Benefits:
- **Better Features**: Student learns richer representations
- **Few-Shot Learning**: Transfers better to new classes
- **Robustness**: Less sensitive to label noise
- **Interpretability**: Embedding space is more structured
- **Complementary**: Can be combined with standard distillation

Use Cases:
- Few-shot/zero-shot learning
- Transfer learning across domains
- Learning with noisy labels
- Metric learning tasks (face recognition, image retrieval)
- Self-supervised pre-training

Performance Improvements:
- CRD has been reported to give roughly 2-4% better accuracy than standard distillation; gains vary with the teacher/student pair and dataset
- Particularly strong for small student models
- Well suited to tasks that depend on good embeddings

References:
- Tian et al. (2020). Contrastive Representation Distillation. ICLR.
- Chen et al. (2020). A Simple Framework for Contrastive Learning of Visual Representations. ICML.

Constructors

ContrastiveDistillationStrategy(string, double, double, double, int, ContrastiveMode)

Initializes a new instance of the ContrastiveDistillationStrategy class.

public ContrastiveDistillationStrategy(string embeddingLayerName = "embeddings", double contrastiveWeight = 0.8, double temperature = 0.07, double alpha = 0.2, int negativesSampleSize = 1024, ContrastiveMode mode = ContrastiveMode.NTXent)

Parameters

embeddingLayerName string

Name of the layer to extract embeddings from (e.g., "layer_before_classifier").

contrastiveWeight double

Weight of the contrastive loss relative to the standard output loss (default: 0.8).

temperature double

Temperature τ that scales the similarities in the contrastive softmax (default: 0.07).

alpha double

Balance between the hard loss (true labels) and the soft loss (teacher outputs) in the standard output loss (default: 0.2).

negativesSampleSize int

Number of negative samples to use (default: 1024).

mode ContrastiveMode

Contrastive mode (default: NTXent for label-free operation).

Remarks

For Beginners: Typical settings for contrastive learning:

- contrastiveWeight 0.6-0.9: more focus on learning representations
- temperature 0.05-0.1: lower values give sharper distinctions between similar and dissimilar samples
- negativesSampleSize 512-2048: more negatives generally improve discrimination, at a higher compute cost
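To see why lower temperatures sharpen distinctions, the sketch below applies a temperature-scaled softmax to one anchor's similarities. The similarity values are made up for illustration, and SoftmaxWithTemperature is a hypothetical helper, not an AiDotNet API.

```csharp
using System;
using System.Linq;

// Softmax over similarity scores scaled by 1/tau. Subtracting the maximum first
// changes nothing mathematically but avoids overflow for small tau.
static double[] SoftmaxWithTemperature(double[] similarities, double tau)
{
    double max = similarities.Max();
    double[] exps = similarities.Select(x => Math.Exp((x - max) / tau)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

// Illustrative cosine similarities of one anchor to its positive (0.9) and two negatives.
double[] sims = { 0.9, 0.7, 0.1 };
foreach (double tau in new[] { 1.0, 0.1, 0.05 })
{
    double[] probs = SoftmaxWithTemperature(sims, tau);
    Console.WriteLine($"tau = {tau}: p(positive) = {probs[0]:F3}");
}
```

As τ drops from 1.0 to 0.05, the probability assigned to the positive rises from about 0.44 to about 0.98, so small similarity gaps are penalized much more strongly.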

Example:

using AiDotNet.KnowledgeDistillation.Strategies;

var strategy = new ContrastiveDistillationStrategy<double>(
    embeddingLayerName: "embedding_layer",
    contrastiveWeight: 0.8,     // 80% contrastive, 20% standard
    temperature: 0.07,          // Common choice in contrastive learning
    alpha: 0.2,                 // Mostly teacher knowledge (soft targets)
    negativesSampleSize: 1024   // Large negative set for better discrimination
);
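The parameter descriptions do not spell out how contrastiveWeight and alpha combine. One plausible reading, consistent with the comments in the example above ("80% contrastive, 20% standard"; alpha 0.2 as "mostly teacher knowledge"), is a nested weighted sum. The sketch below encodes that reading as an assumption, not as confirmed AiDotNet behavior; CombinedLoss is a hypothetical helper.

```csharp
using System;

// ASSUMPTION, not confirmed AiDotNet behavior:
// total = w * contrastive + (1 - w) * (alpha * hard + (1 - alpha) * soft)
static double CombinedLoss(double contrastive, double hard, double soft,
                           double contrastiveWeight, double alpha)
{
    // Standard distillation part: alpha on the hard (true-label) loss,
    // 1 - alpha on the soft (teacher) loss.
    double outputLoss = alpha * hard + (1 - alpha) * soft;
    return contrastiveWeight * contrastive + (1 - contrastiveWeight) * outputLoss;
}

// Constructor defaults: contrastiveWeight = 0.8, alpha = 0.2.
double total = CombinedLoss(contrastive: 2.0, hard: 1.0, soft: 0.5,
                            contrastiveWeight: 0.8, alpha: 0.2);
Console.WriteLine($"total loss: {total:F2}");
```

With these example values the output part is 0.2·1.0 + 0.8·0.5 = 0.6, and the total is 0.8·2.0 + 0.2·0.6 = 1.72.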

Methods

ComputeContrastiveLoss(Vector<T>[], Vector<T>[], int[])

Computes contrastive loss on embeddings/features.

public T ComputeContrastiveLoss(Vector<T>[] studentEmbeddings, Vector<T>[] teacherEmbeddings, int[] labels)

Parameters

studentEmbeddings Vector<T>[]

Student embeddings for the batch, one vector per sample.

teacherEmbeddings Vector<T>[]

Teacher embeddings for the batch, one vector per sample.

labels int[]

Class label of each sample, used to determine positive and negative pairs.

Returns

T

Contrastive loss value.

Remarks

For Beginners: This measures how well the student's embedding space matches the teacher's structural relationships. Lower loss means better match.

The embeddings should be from intermediate layers (not final outputs), as those contain richer representation information.

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the standard output loss (the contrastive gradient is computed separately by ComputeIntermediateGradient).

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>
teacherBatchOutput Matrix<T>
trueLabelsBatch Matrix<T>

Returns

Matrix<T>

ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)

Computes gradients of intermediate activation loss with respect to student embeddings.

public IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)

Parameters

studentIntermediateActivations IntermediateActivations<T>

Student's intermediate layer activations (must include embedding layer).

teacherIntermediateActivations IntermediateActivations<T>

Teacher's intermediate layer activations (must include embedding layer).

Returns

IntermediateActivations<T>

Gradients for the embedding layer (already weighted by contrastiveWeight).

Remarks

Computes analytical gradients for NT-Xent loss with respect to student embeddings. Uses cosine similarity gradients derived from the quotient rule.

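For the NT-Xent loss L_i = -sim(s_i, t_i)/τ + log(Σ_j exp(sim(s_i, t_j)/τ)), where sim is cosine similarity and j ranges over the teacher embeddings in the batch, the gradient combines a contribution from the positive pair (s_i, t_i) with softmax-weighted contributions from every pair (s_i, t_j).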
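A sketch of the analytical gradient, written from L_i above and checked against central finite differences. It assumes the batch loss is the mean of L_i over the N samples and uses double[] arrays; NtXent and NtXentGradient are hypothetical helpers, not the library's code, and the contrastiveWeight scaling the method applies is left out. Because L_i depends on the student only through s_i, dL/ds_i = (1/(Nτ)) Σ_j (p_ij - δ_ij) ∇cos(s_i, t_j), where p_ij is the softmax over j and ∇_s cos(s, t) = t/(‖s‖‖t‖) - cos(s, t)·s/‖s‖².

```csharp
using System;

static double Norm(double[] v)
{
    double acc = 0;
    foreach (double x in v) acc += x * x;
    return Math.Sqrt(acc);
}

static double Cos(double[] a, double[] b)
{
    double dot = 0;
    for (int k = 0; k < a.Length; k++) dot += a[k] * b[k];
    return dot / (Norm(a) * Norm(b));
}

// Mean NT-Xent over the batch: (1/N) Σ_i [ -sim(s_i,t_i)/τ + log Σ_j exp(sim(s_i,t_j)/τ) ].
static double NtXent(double[][] s, double[][] t, double tau)
{
    int n = s.Length;
    double total = 0;
    for (int i = 0; i < n; i++)
    {
        double sumExp = 0;
        for (int j = 0; j < n; j++) sumExp += Math.Exp(Cos(s[i], t[j]) / tau);
        total += -Cos(s[i], t[i]) / tau + Math.Log(sumExp);
    }
    return total / n;
}

// dL/ds_i = (1/(N·τ)) Σ_j (p_ij - δ_ij) ∇cos(s_i, t_j),
// where ∇_s cos(s, t) = t/(|s||t|) - cos(s, t)·s/|s|² (quotient rule).
static double[][] NtXentGradient(double[][] s, double[][] t, double tau)
{
    int n = s.Length, d = s[0].Length;
    var grad = new double[n][];
    for (int i = 0; i < n; i++)
    {
        grad[i] = new double[d];
        double ns = Norm(s[i]);
        var w = new double[n];
        double sumExp = 0;
        for (int j = 0; j < n; j++) { w[j] = Math.Exp(Cos(s[i], t[j]) / tau); sumExp += w[j]; }
        for (int j = 0; j < n; j++)
        {
            // Softmax probability p_ij, minus 1 for the positive pair (j == i).
            double coeff = w[j] / sumExp - (j == i ? 1.0 : 0.0);
            double c = Cos(s[i], t[j]), nt = Norm(t[j]);
            for (int k = 0; k < d; k++)
                grad[i][k] += coeff * (t[j][k] / (ns * nt) - c * s[i][k] / (ns * ns)) / (n * tau);
        }
    }
    return grad;
}

double[][] student = { new[] { 0.5, 1.0, -0.3 }, new[] { -0.2, 0.4, 0.9 } };
double[][] teacher = { new[] { 0.6, 0.8, 0.0 }, new[] { 0.0, 0.3, 1.0 } };
double[][] g = NtXentGradient(student, teacher, 0.5);

// Central finite difference on one coordinate as a sanity check.
const double h = 1e-6;
double x0 = student[0][1];
student[0][1] = x0 + h;
double up = NtXent(student, teacher, 0.5);
student[0][1] = x0 - h;
double down = NtXent(student, teacher, 0.5);
student[0][1] = x0; // restore the original value
Console.WriteLine($"analytic: {g[0][1]:E4}, numeric: {(up - down) / (2 * h):E4}");
```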

ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)

Computes intermediate activation loss by matching embedding space structure between teacher and student.

public T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)

Parameters

studentIntermediateActivations IntermediateActivations<T>

Student's intermediate layer activations (must include embedding layer).

teacherIntermediateActivations IntermediateActivations<T>

Teacher's intermediate layer activations (must include embedding layer).

Returns

T

The contrastive loss (already weighted by contrastiveWeight).

Remarks

This implements the IIntermediateActivationStrategy interface to properly integrate contrastive learning into the training loop. The loss is computed from embeddings stored in the intermediate activations for the layer specified in the constructor.

NTXent mode is used by default because it does not require labels: the teacher and student embeddings of the same sample form a positive pair, and pairs formed with other samples in the batch are negatives.

If the embedding layer is not found, returns zero loss.
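The flow described above (look up the configured layer, return zero if it is missing, otherwise compute label-free NT-Xent and scale it by contrastiveWeight) can be sketched as follows. A Dictionary<string, double[][]> stands in for IntermediateActivations<T>, whose members are not documented here; IntermediateLoss and Cosine are hypothetical helpers, and the mean over the batch is an assumption.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for IntermediateActivations<T>: layer name -> embeddings (one row per sample).
static double IntermediateLoss(
    Dictionary<string, double[][]> studentActs,
    Dictionary<string, double[][]> teacherActs,
    string embeddingLayerName, double contrastiveWeight, double tau)
{
    // If the embedding layer is missing, contribute zero loss.
    if (!studentActs.TryGetValue(embeddingLayerName, out var s) ||
        !teacherActs.TryGetValue(embeddingLayerName, out var t))
        return 0.0;

    // Label-free NT-Xent: (s_i, t_i) is the positive pair; (s_i, t_j), j != i, are negatives.
    int n = s.Length;
    double total = 0;
    for (int i = 0; i < n; i++)
    {
        double sumExp = 0, positive = 0;
        for (int j = 0; j < n; j++)
        {
            double logit = Cosine(s[i], t[j]) / tau;
            sumExp += Math.Exp(logit);
            if (j == i) positive = logit;
        }
        total += Math.Log(sumExp) - positive;
    }
    // Mean over the batch, already scaled by contrastiveWeight.
    return contrastiveWeight * total / n;
}

static double Cosine(double[] a, double[] b)
{
    double dot = 0, na = 0, nb = 0;
    for (int k = 0; k < a.Length; k++) { dot += a[k] * b[k]; na += a[k] * a[k]; nb += b[k] * b[k]; }
    return dot / Math.Sqrt(na * nb);
}

var studentLayers = new Dictionary<string, double[][]>
{
    ["embeddings"] = new[] { new[] { 1.0, 0.1 }, new[] { 0.2, 1.0 } }
};
var teacherLayers = new Dictionary<string, double[][]>
{
    ["embeddings"] = new[] { new[] { 0.9, 0.0 }, new[] { 0.0, 1.0 } }
};
Console.WriteLine(IntermediateLoss(studentLayers, teacherLayers, "embeddings", 0.8, 0.07));
Console.WriteLine(IntermediateLoss(studentLayers, teacherLayers, "missing_layer", 0.8, 0.07)); // prints 0
```

The layer key "embeddings" matches the constructor's default embeddingLayerName; any other name falls through to the zero-loss branch.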

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes standard output loss (contrastive loss computed separately on embeddings).

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>
teacherBatchOutput Matrix<T>
trueLabelsBatch Matrix<T>

Returns

T