Class AttentionDistillationStrategy<T>
- Namespace: AiDotNet.KnowledgeDistillation.Strategies
- Assembly: AiDotNet.dll
Implements attention-based knowledge distillation for transformer models. Transfers knowledge through attention patterns rather than just final outputs.
public class AttentionDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IIntermediateActivationStrategy<T>
Type Parameters
- T: The numeric type for calculations (e.g., double, float).
Inheritance
- DistillationStrategyBase<T> → AttentionDistillationStrategy<T>
Implements
- IDistillationStrategy<T>
- IIntermediateActivationStrategy<T>
Remarks
For Beginners: Attention mechanisms in transformers tell us "what the model is focusing on." Instead of just copying the teacher's final answers, attention distillation teaches the student to focus on the same things the teacher focuses on.
Real-world Analogy: Imagine learning to play chess from a grandmaster. Instead of just copying their moves (outputs), you also learn where they look on the board and what pieces they pay attention to. This deeper understanding helps you think like the master, not just mimic their moves.
Why Attention Distillation?
- **Richer Knowledge**: Attention patterns reveal the reasoning process.
- **Better for Transformers**: Transformers rely heavily on attention.
- **Interpretability**: You can see what the student learned to focus on.
- **Complementary**: Works alongside response-based distillation.
How It Works:
1. Extract attention weights from teacher layers.
2. Extract attention weights from student layers.
3. Minimize the MSE between the attention distributions.
4. Combine with the standard output distillation loss.
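The four steps can be sketched in a few lines of plain Python (no AiDotNet types); the layer names and all numeric values below are made up for illustration:

```python
# Minimal sketch of attention distillation, steps 1-4.
# Layer names and values are illustrative, not the library's API.

def mse(a, b):
    """Mean squared error between two equal-length attention vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# 1-2. Attention weights extracted per layer from teacher and student
# (in practice, flattened [num_heads, seq_len, seq_len] maps).
teacher_attention = {"layer.0": [0.7, 0.2, 0.1], "layer.1": [0.5, 0.3, 0.2]}
student_attention = {"layer.0": [0.6, 0.3, 0.1], "layer.1": [0.4, 0.4, 0.2]}

# 3. MSE between the attention distributions, averaged over layers.
attention_loss = sum(
    mse(teacher_attention[name], student_attention[name])
    for name in teacher_attention
) / len(teacher_attention)

# 4. Combine with the standard output distillation loss.
w = 0.3              # attentionWeight
output_loss = 0.45   # placeholder for the usual soft/hard distillation loss
total = (1 - w) * output_loss + w * attention_loss
```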
Attention Matching Strategies:
- **Layer-wise**: Match corresponding layers (e.g., teacher layer 6 → student layer 3).
- **Head-wise**: Match individual attention heads.
- **Global**: Match attention averaged across all heads.
- **Selective**: Match only the most important heads.
Common Applications:
- **DistilBERT**: Used attention distillation to compress BERT.
- **TinyBERT**: Attention transfer plus representation transfer.
- **MobileBERT**: Layer-wise attention matching.
- **Vision Transformers**: Attention distillation for ViT compression.
Benefits:
- Preserves the model's "reasoning" process.
- Improves the student's interpretability.
- Often yields 2-5% better accuracy than output-only distillation.
- Helps with few-shot and zero-shot transfer.
References:
- Sanh et al. (2019). DistilBERT: A Distilled Version of BERT. arXiv:1910.01108.
- Jiao et al. (2020). TinyBERT: Distilling BERT for Natural Language Understanding. EMNLP.
- Wang et al. (2020). MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression. NeurIPS.
Constructors
AttentionDistillationStrategy(string[], double, double, double, AttentionMatchingMode)
Initializes a new instance of the AttentionDistillationStrategy class.
public AttentionDistillationStrategy(string[] attentionLayers, double attentionWeight = 0.3, double temperature = 3, double alpha = 0.3, AttentionMatchingMode matchingMode = AttentionMatchingMode.MSE)
Parameters
- attentionLayers (string[]): Names of attention layers to match (e.g., ["layer.0.attention", "layer.1.attention"]).
- attentionWeight (double): Weight for attention loss vs. output loss (default: 0.3).
- temperature (double): Temperature for softmax scaling (default: 3.0).
- alpha (double): Balance between hard and soft loss (default: 0.3).
- matchingMode (AttentionMatchingMode): How to match attention patterns (default: MSE).
Remarks
For Beginners: Specify which attention layers to match and how much weight to give to attention matching vs. output matching.
Example for BERT-like model:
```csharp
var strategy = new AttentionDistillationStrategy<double>(
    attentionLayers: new[] {
        "encoder.layer.0.attention",
        "encoder.layer.3.attention",
        "encoder.layer.6.attention"
    },
    attentionWeight: 0.3,  // 30% attention, 70% output
    temperature: 2.0,
    alpha: 0.5
);
```
Layer Selection Tips:
- **Early layers**: Low-level patterns (syntax, local features).
- **Middle layers**: Mid-level concepts (phrases, object parts).
- **Late layers**: High-level semantics (meaning, objects).
- **All layers**: Most comprehensive, but computationally expensive.
- **Selective**: Match 2-3 key layers for efficiency.
Weight Selection:
- 0.1-0.2: Slight attention guidance.
- 0.3-0.4: Balanced (recommended for most cases).
- 0.5-0.7: Strong attention focus.
- 0.8+: Primarily attention-driven (risky).
Methods
ComputeAttentionLoss(Func<string, Vector<T>>, Func<string, Vector<T>>)
Computes attention matching loss between teacher and student attention patterns.
public T ComputeAttentionLoss(Func<string, Vector<T>> teacherAttentionExtractor, Func<string, Vector<T>> studentAttentionExtractor)
Parameters
- teacherAttentionExtractor (Func<string, Vector<T>>): Function to extract teacher attention for a layer.
- studentAttentionExtractor (Func<string, Vector<T>>): Function to extract student attention for a layer.
Returns
- T
Attention matching loss.
Remarks
For Beginners: This measures how different the attention patterns are. A lower loss means the student is focusing on the same things as the teacher.
The extractors should return attention weights as vectors, typically flattened from [num_heads, seq_len, seq_len] matrices.
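As a sketch of what such extractors might return, here is a plain-Python flattening of a [num_heads, seq_len, seq_len] attention tensor, followed by the MSE comparison; all shapes and numbers are made up for illustration:

```python
# Illustrative attention tensors: attention[h][i][j] = how much head h's
# query token i attends to token j (2 heads, seq_len 2).
teacher = [[[0.9, 0.1], [0.4, 0.6]],
           [[0.5, 0.5], [0.2, 0.8]]]
student = [[[0.8, 0.2], [0.5, 0.5]],
           [[0.5, 0.5], [0.3, 0.7]]]

def flatten(attention):
    """[num_heads, seq_len, seq_len] -> flat list of num_heads * seq_len**2 weights."""
    return [w for head in attention for row in head for w in row]

t, s = flatten(teacher), flatten(student)
loss = sum((a - b) ** 2 for a, b in zip(t, s)) / len(t)  # MSE over all weights
```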
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes gradient of the combined loss.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
- studentBatchOutput (Matrix<T>)
- teacherBatchOutput (Matrix<T>)
- trueLabelsBatch (Matrix<T>)
Returns
- Matrix<T>
ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)
Computes gradients of intermediate activation loss with respect to student activations.
public IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
- studentIntermediateActivations (IntermediateActivations<T>): Student's intermediate layer activations.
- teacherIntermediateActivations (IntermediateActivations<T>): Teacher's intermediate layer activations.
Returns
- IntermediateActivations<T>
Gradients for each attention layer (already weighted by attentionWeight).
ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)
Computes intermediate activation loss by matching attention patterns between teacher and student.
public T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
- studentIntermediateActivations (IntermediateActivations<T>): Student's intermediate layer activations (must include attention layers).
- teacherIntermediateActivations (IntermediateActivations<T>): Teacher's intermediate layer activations (must include attention layers).
Returns
- T
The attention matching loss (already weighted by attentionWeight).
Remarks
This implements the IIntermediateActivationStrategy interface to properly integrate attention matching into the training loop. The loss is computed from attention patterns stored in the intermediate activations for layers specified in the constructor.
If a target layer is not found in the activations, it is skipped; if none of the target layers are found, the method returns zero.
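A dict-based stand-in for IntermediateActivations<T> can sketch this skip-missing-layers behavior; the layer names, the averaging over matched layers, and the default weight below are illustrative assumptions, not the library's exact implementation:

```python
def mse(a, b):
    """Mean squared error between two equal-length attention vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def intermediate_loss(target_layers, student_acts, teacher_acts, attention_weight=0.3):
    matched = [
        mse(student_acts[name], teacher_acts[name])
        for name in target_layers
        if name in student_acts and name in teacher_acts  # skip missing layers
    ]
    if not matched:
        return 0.0  # none of the target layers were found -> zero loss
    # average over matched layers, pre-weighted by attentionWeight
    return attention_weight * sum(matched) / len(matched)

student = {"layer.0.attention": [0.5, 0.5]}
teacher = {"layer.0.attention": [0.6, 0.4], "layer.9.attention": [0.1, 0.9]}
# layer.9.attention is absent on the student side, so only layer.0 contributes
loss = intermediate_loss(["layer.0.attention", "layer.9.attention"], student, teacher)
```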
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes combined distillation loss (output loss + attention loss).
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
- studentBatchOutput (Matrix<T>): Student batch output [batchSize x outputDim].
- teacherBatchOutput (Matrix<T>): Teacher batch output [batchSize x outputDim].
- trueLabelsBatch (Matrix<T>): Optional batch labels [batchSize x outputDim].
Returns
- T
Average loss across the batch.
Remarks
For Beginners: This combines two types of loss:
1. Standard distillation loss on the final outputs.
2. Attention matching loss on intermediate attention patterns.
Formula: L = (1 - w) × L_output + w × L_attention, where w is attentionWeight.
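For one example, the combination can be sketched in plain Python, assuming a standard Hinton-style soft/hard blend for L_output (that the library uses exactly this blend, and that alpha weights the hard-label term, are assumptions; all logits and values are illustrative):

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted):
    return -sum(t * math.log(q) for t, q in zip(target, predicted))

teacher_logits = [2.0, 1.0, 0.1]
student_logits = [1.5, 1.2, 0.3]
hard_label = [1.0, 0.0, 0.0]
T, alpha, w = 3.0, 0.3, 0.3  # temperature, alpha, attentionWeight

# Soft loss is scaled by T^2 to keep gradient magnitudes comparable.
soft = cross_entropy(softmax(teacher_logits, T), softmax(student_logits, T)) * T * T
hard = cross_entropy(hard_label, softmax(student_logits))
l_output = alpha * hard + (1 - alpha) * soft  # assumed hard/soft convention

l_attention = 0.0075  # e.g. MSE over flattened attention maps
loss = (1 - w) * l_output + w * l_attention   # L = (1 - w) * L_output + w * L_attention
```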