Class FeatureDistillationStrategy<T>

Namespace: AiDotNet.KnowledgeDistillation
Assembly: AiDotNet.dll

Implements feature-based knowledge distillation (FitNets) where the student learns to match the teacher's intermediate layer representations.

public class FeatureDistillationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
object ← FeatureDistillationStrategy<T>

Remarks

For Beginners: While standard distillation transfers knowledge through final outputs, feature distillation goes deeper by matching intermediate layer activations. This helps the student learn not just what the teacher predicts, but how it thinks.

Why Feature Distillation?
- **Better for Different Architectures**: Works when student and teacher have very different structures
- **Richer Knowledge Transfer**: Captures hierarchical feature learning
- **Improved Generalization**: The student learns more robust representations
- **Complementary to Response Distillation**: Can be combined with standard output-based distillation

Real-world Analogy: Imagine learning to paint from a master artist. Standard distillation is like copying only the final painting. Feature distillation is like watching the master's brush strokes, color mixing, and layering techniques - learning the process, not just the result.

How It Works:
1. Extract features from a teacher layer (e.g., conv3 in ResNet)
2. Extract features from the corresponding student layer
3. Minimize the MSE (Mean Squared Error) between them
4. Optionally use a projection layer if dimensions don't match
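
To make these steps concrete, here is a minimal sketch of the per-layer matching in plain C#. The double[] features, Mse helper, and projection matrix are illustrative assumptions; the actual strategy operates on Vector<T>.

// Sketch only: the real strategy works on Vector<T>; double[] keeps the idea visible.
public static class FeatureMatchingSketch
{
    // Step 3: mean squared error between one teacher/student feature pair
    // (assumes both vectors have the same length).
    public static double Mse(double[] student, double[] teacher)
    {
        double sum = 0.0;
        for (int i = 0; i < student.Length; i++)
        {
            double diff = student[i] - teacher[i];
            sum += diff * diff;
        }
        return sum / student.Length;
    }

    // Step 4: if dimensions differ, map student features through a
    // (hypothetical) learned linear projection before comparing.
    public static double[] Project(double[] studentFeatures, double[,] projection)
    {
        int rows = projection.GetLength(0);
        int cols = projection.GetLength(1);
        var projected = new double[rows];
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++)
                projected[r] += projection[r, c] * studentFeatures[c];
        return projected;
    }
}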

Common Applications:
- ResNet → MobileNet: match convolutional feature maps
- BERT → DistilBERT: match transformer layer outputs
- Teacher and student with different widths/depths

References:
- Romero, A., et al. (2014). FitNets: Hints for Thin Deep Nets. arXiv:1412.6550

Constructors

FeatureDistillationStrategy(string[], double)

Initializes a new instance of the FeatureDistillationStrategy class.

public FeatureDistillationStrategy(string[] layerPairs, double featureWeight = 0.5)

Parameters

layerPairs string[]

Names of layer pairs to match (teacher_layer:student_layer format). Example: ["conv3:conv2", "conv4:conv3"]

featureWeight double

Weight for feature loss vs. output loss (default 0.5). Higher values emphasize matching intermediate features.

Remarks

For Beginners: Layer pairs specify which teacher layers should be matched to which student layers. Format: "teacher_layer_name:student_layer_name"

Example usage:

// Match teacher's layer 3 to student's layer 2, and teacher's 4 to student's 3
var strategy = new FeatureDistillationStrategy<double>(
    layerPairs: new[] { "layer3:layer2", "layer4:layer3" },
    featureWeight: 0.5  // Equal weight to features and outputs
);

Tips for choosing layer pairs:
- Match semantically similar layers (similar relative depth in the network)
- Start with 1-2 pairs and add more if needed
- Earlier layers capture low-level features (edges, textures)
- Later layers capture high-level features (objects, concepts)
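
The featureWeight parameter then blends the feature term with the standard output term. The convex combination below is an illustrative assumption, not necessarily the library's exact formula:

// Hypothetical blending of the two loss terms; the exact formula used
// internally may differ. Placeholder values stand in for real losses.
double outputLoss = 0.42;   // placeholder: standard (response) distillation loss
double featureLoss = 0.17;  // placeholder: result of ComputeFeatureLoss
double featureWeight = 0.5;
double totalLoss = (1.0 - featureWeight) * outputLoss
                 + featureWeight * featureLoss;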

Methods

ComputeFeatureGradient(Vector<T>, Vector<T>)

Computes the gradient of feature loss for backpropagation.

public Vector<T> ComputeFeatureGradient(Vector<T> studentFeatures, Vector<T> teacherFeatures)

Parameters

studentFeatures Vector<T>

Student's feature vector.

teacherFeatures Vector<T>

Teacher's feature vector.

Returns

Vector<T>

Gradient vector for backpropagation.

Remarks

For Beginners: The gradient of MSE is simple: ∂MSE/∂student_i = (2/n) × (student_i − teacher_i), where n is the feature dimension.

This tells us: if a student feature value is too high relative to the teacher's, the gradient decreases it; if too low, it increases it.
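
As a sketch, the elementwise computation looks like the following (plain double[] for clarity; the method itself takes and returns Vector<T>):

// d(MSE)/d(student_i) = (2/n) × (student_i − teacher_i)
public static double[] MseGradient(double[] student, double[] teacher)
{
    int n = student.Length;
    var gradient = new double[n];
    for (int i = 0; i < n; i++)
        gradient[i] = (2.0 / n) * (student[i] - teacher[i]);
    return gradient;
}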

ComputeFeatureLoss(Func<string, Vector<T>>, Func<string, Vector<T>>, Vector<T>)

Computes the feature matching loss between student and teacher intermediate representations.

public T ComputeFeatureLoss(Func<string, Vector<T>> teacherFeatureExtractor, Func<string, Vector<T>> studentFeatureExtractor, Vector<T> input)

Parameters

teacherFeatureExtractor Func<string, Vector<T>>

Function to extract teacher features for a layer name.

studentFeatureExtractor Func<string, Vector<T>>

Function to extract student features for a layer name.

input Vector<T>

Input data for forward pass.

Returns

T

Mean squared error between matched feature pairs.

Remarks

For Beginners: This computes how different the student's internal features are from the teacher's. Lower loss means the student is learning to think like the teacher.

The loss is computed as: L_feature = (1/N) × Σ_i MSE(teacher_features_i, student_features_i), where N is the number of layer pairs.

If feature dimensions don't match, consider adding a projection layer (simple linear transformation) to the student.
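
Example usage (a hedged sketch: GetActivations is a hypothetical stand-in for however your models expose intermediate layer outputs):

// teacherModel.GetActivations / studentModel.GetActivations are hypothetical;
// substitute your models' actual feature accessors.
var featureLoss = strategy.ComputeFeatureLoss(
    teacherFeatureExtractor: layerName => teacherModel.GetActivations(layerName),
    studentFeatureExtractor: layerName => studentModel.GetActivations(layerName),
    input: input
);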