
Class AttentionDistillationStrategy<T>

Namespace: AiDotNet.KnowledgeDistillation.Strategies
Assembly: AiDotNet.dll

Implements attention-based knowledge distillation for transformer models. Transfers knowledge through attention patterns rather than just final outputs.

public class AttentionDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IIntermediateActivationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
object → DistillationStrategyBase<T> → AttentionDistillationStrategy<T>

Implements
IDistillationStrategy<T>, IIntermediateActivationStrategy<T>

Remarks

For Beginners: Attention mechanisms in transformers tell us "what the model is focusing on." Instead of just copying the teacher's final answers, attention distillation teaches the student to focus on the same things the teacher focuses on.

Real-world Analogy: Imagine learning to play chess from a grandmaster. Instead of just copying their moves (outputs), you also learn where they look on the board and what pieces they pay attention to. This deeper understanding helps you think like the master, not just mimic their moves.

Why Attention Distillation?
- **Richer Knowledge**: Attention patterns reveal the reasoning process
- **Better for Transformers**: Transformers rely heavily on attention
- **Interpretability**: You can see what the student learned to focus on
- **Complementary**: Works alongside response-based distillation

How It Works:
1. Extract attention weights from teacher layers
2. Extract attention weights from student layers
3. Minimize the MSE between attention distributions
4. Combine with the standard output distillation loss
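The steps above are library-agnostic. Below is a minimal NumPy sketch of the same computation, for illustration only — the function names, shapes, and the cross-entropy form of the output loss are assumptions, not AiDotNet's implementation:

```python
import numpy as np

def softmax(logits, temperature):
    z = np.asarray(logits, dtype=float) / temperature
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_distillation_loss(teacher_attn, student_attn,
                                teacher_logits, student_logits,
                                attention_weight=0.3, temperature=3.0):
    # Steps 1-3: MSE between corresponding teacher/student attention maps
    # (each map here is a [num_heads, seq_len, seq_len] array).
    attention_loss = np.mean([np.mean((t - s) ** 2)
                              for t, s in zip(teacher_attn, student_attn)])

    # Step 4: soft-target output loss (cross-entropy against the
    # temperature-softened teacher distribution).
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    output_loss = -np.mean(np.sum(p_teacher * np.log(p_student + 1e-12), axis=-1))

    return (1 - attention_weight) * output_loss + attention_weight * attention_loss
```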

Attention Matching Strategies:
- **Layer-wise**: Match corresponding layers (e.g., teacher layer 6 → student layer 3)
- **Head-wise**: Match individual attention heads
- **Global**: Match averaged attention across all heads
- **Selective**: Match only the most important heads

Common Applications:
- **DistilBERT**: Used attention distillation to compress BERT
- **TinyBERT**: Attention transfer + representation transfer
- **MobileBERT**: Layer-wise attention matching
- **Vision Transformers**: Attention distillation for ViT compression

Benefits:
- Preserves the model's "reasoning" process
- Improves the student's interpretability
- Often yields 2-5% better accuracy than output-only distillation
- Helps with few-shot and zero-shot transfer

References:
- Sanh et al. (2019). DistilBERT: A Distilled Version of BERT. arXiv:1910.01108
- Jiao et al. (2020). TinyBERT: Distilling BERT for Natural Language Understanding. EMNLP.
- Wang et al. (2020). MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression.

Constructors

AttentionDistillationStrategy(string[], double, double, double, AttentionMatchingMode)

Initializes a new instance of the AttentionDistillationStrategy class.

public AttentionDistillationStrategy(string[] attentionLayers, double attentionWeight = 0.3, double temperature = 3, double alpha = 0.3, AttentionMatchingMode matchingMode = AttentionMatchingMode.MSE)

Parameters

attentionLayers string[]

Names of attention layers to match (e.g., ["layer.0.attention", "layer.1.attention"]).

attentionWeight double

Weight for attention loss vs. output loss (default: 0.3).

temperature double

Temperature for softmax scaling (default: 3.0).

alpha double

Balance between hard and soft loss (default: 0.3).

matchingMode AttentionMatchingMode

How to match attention patterns (default: MSE).

Remarks

For Beginners: Specify which attention layers to match and how much weight to give to attention matching vs. output matching.

Example for BERT-like model:

var strategy = new AttentionDistillationStrategy<double>(
    attentionLayers: new[] {
        "encoder.layer.0.attention",
        "encoder.layer.3.attention",
        "encoder.layer.6.attention"
    },
    attentionWeight: 0.3,  // 30% attention, 70% output
    temperature: 2.0,
    alpha: 0.5
);

Layer Selection Tips:
- **Early layers**: Low-level patterns (syntax, local features)
- **Middle layers**: Mid-level concepts (phrases, object parts)
- **Late layers**: High-level semantics (meaning, objects)
- **All layers**: Most comprehensive but computationally expensive
- **Selective**: Match 2-3 key layers for efficiency

Weight Selection:
- 0.1-0.2: Slight attention guidance
- 0.3-0.4: Balanced (recommended for most cases)
- 0.5-0.7: Strong attention focus
- 0.8+: Primarily attention-driven (risky)
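The temperature parameter controls how much the teacher's output distribution is softened before matching. A quick illustration of the effect (plain Python/NumPy for demonstration, not AiDotNet code):

```python
import numpy as np

def softened_softmax(logits, temperature):
    # Higher temperature flattens the distribution, exposing the teacher's
    # relative preferences among non-argmax classes ("dark knowledge").
    z = np.asarray(logits, dtype=float) / temperature
    e = np.exp(z - z.max())
    return e / e.sum()

logits = [5.0, 2.0, 1.0]
sharp = softened_softmax(logits, 1.0)  # nearly one-hot
soft = softened_softmax(logits, 3.0)   # this class's default temperature
```

With temperature 3.0 the runner-up classes receive noticeably more probability mass, which is what gives the student a richer training signal.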

Methods

ComputeAttentionLoss(Func<string, Vector<T>>, Func<string, Vector<T>>)

Computes attention matching loss between teacher and student attention patterns.

public T ComputeAttentionLoss(Func<string, Vector<T>> teacherAttentionExtractor, Func<string, Vector<T>> studentAttentionExtractor)

Parameters

teacherAttentionExtractor Func<string, Vector<T>>

Function to extract teacher attention for a layer.

studentAttentionExtractor Func<string, Vector<T>>

Function to extract student attention for a layer.

Returns

T

Attention matching loss.

Remarks

For Beginners: This measures how different the attention patterns are. Lower loss means the student is focusing on the same things as the teacher.

The extractors should return attention weights as vectors, typically flattened from [num_heads, seq_len, seq_len] matrices.
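As a sketch of what the extractors might return and how the per-layer loss could be averaged — the layer names, shapes, and averaging are hypothetical, not AiDotNet's actual extraction API:

```python
import numpy as np

num_heads, seq_len = 4, 8
rng = np.random.default_rng(0)

def fake_attention(rng):
    # Hypothetical [num_heads, seq_len, seq_len] attention map, flattened to
    # a vector; each row is normalized so attention weights sum to 1.
    a = rng.random((num_heads, seq_len, seq_len))
    return (a / a.sum(axis=-1, keepdims=True)).ravel()

teacher_attn = {"encoder.layer.0.attention": fake_attention(rng)}
student_attn = {"encoder.layer.0.attention": fake_attention(rng)}

def attention_matching_loss(layer_names, teacher_extract, student_extract):
    # Average MSE between teacher and student attention vectors, per layer.
    per_layer = [np.mean((teacher_extract(name) - student_extract(name)) ** 2)
                 for name in layer_names]
    return float(np.mean(per_layer))

loss = attention_matching_loss(["encoder.layer.0.attention"],
                               teacher_attn.get, student_attn.get)
```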

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes gradient of the combined loss.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

Student batch output [batchSize x outputDim].

teacherBatchOutput Matrix<T>

Teacher batch output [batchSize x outputDim].

trueLabelsBatch Matrix<T>

Optional batch labels [batchSize x outputDim].

Returns

Matrix<T>

Gradient of the combined loss with respect to studentBatchOutput.

ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)

Computes gradients of intermediate activation loss with respect to student activations.

public IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)

Parameters

studentIntermediateActivations IntermediateActivations<T>

Student's intermediate layer activations.

teacherIntermediateActivations IntermediateActivations<T>

Teacher's intermediate layer activations.

Returns

IntermediateActivations<T>

Gradients for each attention layer (already weighted by attentionWeight).
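For MSE-style matching, the per-layer gradient is proportional to the difference between student and teacher activations. A hedged sketch of that calculus (the mean reduction and weighting are assumptions about how the loss is normalized):

```python
import numpy as np

def mse_attention_loss_and_grad(student, teacher, attention_weight=0.3):
    # loss = w * mean((s - t)^2);  dloss/ds = w * 2 * (s - t) / n
    diff = student - teacher
    n = diff.size
    loss = attention_weight * float(np.mean(diff ** 2))
    grad = attention_weight * 2.0 * diff / n
    return loss, grad
```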

ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)

Computes intermediate activation loss by matching attention patterns between teacher and student.

public T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)

Parameters

studentIntermediateActivations IntermediateActivations<T>

Student's intermediate layer activations (must include attention layers).

teacherIntermediateActivations IntermediateActivations<T>

Teacher's intermediate layer activations (must include attention layers).

Returns

T

The attention matching loss (already weighted by attentionWeight).

Remarks

This implements the IIntermediateActivationStrategy interface to properly integrate attention matching into the training loop. The loss is computed from attention patterns stored in the intermediate activations for layers specified in the constructor.

If any target layer is not found, it is skipped. Returns zero if no layers are found.

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes combined distillation loss (output loss + attention loss).

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

Student batch output [batchSize x outputDim].

teacherBatchOutput Matrix<T>

Teacher batch output [batchSize x outputDim].

trueLabelsBatch Matrix<T>

Optional batch labels [batchSize x outputDim].

Returns

T

Average loss across the batch.

Remarks

For Beginners: This combines two types of loss:
1. Standard distillation loss on final outputs
2. Attention matching loss on intermediate attention patterns

Formula: L = (1 - w) × L_output + w × L_attention, where w is attentionWeight.
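A worked example of the formula with hypothetical loss values:

```python
def combined_loss(output_loss, attention_loss, attention_weight=0.3):
    # L = (1 - w) * L_output + w * L_attention
    return (1 - attention_weight) * output_loss + attention_weight * attention_loss

# With hypothetical values L_output = 2.0 and L_attention = 0.5:
loss = combined_loss(2.0, 0.5)  # 0.7 * 2.0 + 0.3 * 0.5 = 1.55
```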