Class FeatureDistillationStrategy<T>
- Namespace
- AiDotNet.KnowledgeDistillation
- Assembly
- AiDotNet.dll
Implements feature-based knowledge distillation (FitNets) where the student learns to match the teacher's intermediate layer representations.
public class FeatureDistillationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
- Inheritance
- object → FeatureDistillationStrategy<T>
Remarks
For Beginners: While standard distillation transfers knowledge through final outputs, feature distillation goes deeper by matching intermediate layer activations. This helps the student learn not just what the teacher predicts, but how it thinks.
Why Feature Distillation?
- **Better for Different Architectures**: Works even when student and teacher have very different structures.
- **Richer Knowledge Transfer**: Captures hierarchical feature learning.
- **Improved Generalization**: The student learns more robust representations.
- **Complementary to Response Distillation**: Can be combined with standard distillation.
Real-world Analogy: Imagine learning to paint from a master artist. Standard distillation is like copying only the final painting. Feature distillation is like watching the master's brush strokes, color mixing, and layering techniques - learning the process, not just the result.
How It Works:
1. Extract features from a teacher layer (e.g., conv3 in ResNet).
2. Extract features from the corresponding student layer.
3. Minimize the MSE (Mean Squared Error) between them.
4. Optionally use a projection layer if the dimensions don't match (see the sketch below).
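For illustration, here is a minimal sketch of that loop in plain C#. The double[] features, the extractor delegates, and the Split-based pair parsing are assumptions for the example, not this class's actual internals:
// Hypothetical sketch: MSE feature-matching loss averaged over layer pairs.
// double[] stands in for Vector<T>.
double FeatureLoss(
    string[] layerPairs,                      // e.g., "layer3:layer2"
    Func<string, double[]> teacherFeatures,   // teacher activations by layer name
    Func<string, double[]> studentFeatures)   // student activations by layer name
{
    double total = 0.0;
    foreach (var pair in layerPairs)
    {
        var names = pair.Split(':');          // [teacherLayer, studentLayer]
        double[] t = teacherFeatures(names[0]);
        double[] s = studentFeatures(names[1]);
        double mse = 0.0;
        for (int i = 0; i < s.Length; i++)    // assumes equal dimensions (otherwise project first)
        {
            double diff = s[i] - t[i];
            mse += diff * diff;
        }
        total += mse / s.Length;
    }
    return total / layerPairs.Length;         // average across all matched pairs
}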
Common Applications:
- ResNet → MobileNet: match convolutional feature maps
- BERT → DistilBERT: match transformer layer outputs
- Teacher and student with different widths/depths
References: - Romero, A., et al. (2014). FitNets: Hints for Thin Deep Nets. arXiv:1412.6550
Constructors
FeatureDistillationStrategy(string[], double)
Initializes a new instance of the FeatureDistillationStrategy class.
public FeatureDistillationStrategy(string[] layerPairs, double featureWeight = 0.5)
Parameters
layerPairs (string[]): Names of layer pairs to match, in teacher_layer:student_layer format. Example: ["conv3:conv2", "conv4:conv3"]
featureWeight (double): Weight for feature loss vs. output loss (default 0.5). Higher values emphasize matching intermediate features.
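This page does not state how the two loss terms are combined; a common convention, consistent with the example comment below where 0.5 means equal weight, would be: L_total = (1 − featureWeight) × L_output + featureWeight × L_feature.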
Remarks
For Beginners: Layer pairs specify which teacher layers should be matched to which student layers. Format: "teacher_layer_name:student_layer_name"
Example usage:
// Match teacher's layer 3 to student's layer 2, and teacher's layer 4 to student's layer 3
var strategy = new FeatureDistillationStrategy<double>(
    layerPairs: new[] { "layer3:layer2", "layer4:layer3" },
    featureWeight: 0.5  // Equal weight to features and outputs
);
Tips for choosing layer pairs:
- Match semantically similar layers (at similar depth in the network).
- Start with 1–2 pairs and add more if needed.
- Earlier layers capture low-level features (edges, textures).
- Later layers capture high-level features (objects, concepts).
Methods
ComputeFeatureGradient(Vector<T>, Vector<T>)
Computes the gradient of feature loss for backpropagation.
public Vector<T> ComputeFeatureGradient(Vector<T> studentFeatures, Vector<T> teacherFeatures)
Parameters
studentFeatures (Vector<T>): Student's feature vector.
teacherFeatures (Vector<T>): Teacher's feature vector.
Returns
- Vector<T>
Gradient vector for backpropagation.
Remarks
For Beginners: The gradient of MSE is simple: ∂MSE/∂student_i = (2/n) × (student_i − teacher_i), where n is the feature dimension.
This tells us: if a student feature is too high, decrease it; if it's too low, increase it.
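As a sketch, the elementwise computation might look like the following (double[] stands in for Vector<T>; the method itself operates on Vector<T>):
// Hypothetical elementwise MSE gradient: grad[i] = (2/n) * (student[i] - teacher[i]).
double[] FeatureGradient(double[] student, double[] teacher)
{
    int n = student.Length;
    var grad = new double[n];
    for (int i = 0; i < n; i++)
    {
        // Positive when the student overshoots the teacher, so a gradient
        // step pushes the feature down; negative when it undershoots.
        grad[i] = (2.0 / n) * (student[i] - teacher[i]);
    }
    return grad;
}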
ComputeFeatureLoss(Func<string, Vector<T>>, Func<string, Vector<T>>, Vector<T>)
Computes the feature matching loss between student and teacher intermediate representations.
public T ComputeFeatureLoss(Func<string, Vector<T>> teacherFeatureExtractor, Func<string, Vector<T>> studentFeatureExtractor, Vector<T> input)
Parameters
teacherFeatureExtractor (Func<string, Vector<T>>): Function that extracts the teacher's features for a given layer name.
studentFeatureExtractor (Func<string, Vector<T>>): Function that extracts the student's features for a given layer name.
input (Vector<T>): Input data for the forward pass.
Returns
- T
Mean squared error between matched feature pairs.
Remarks
For Beginners: This computes how different the student's internal features are from the teacher's. Lower loss means the student is learning to think like the teacher.
The loss is computed as: L_feature = (1/N) × Σ MSE(teacher_features_i, student_features_i) where N is the number of layer pairs.
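A hypothetical call might look like this; teacherModel, studentModel, and GetLayerOutput are placeholders for your own networks and feature-extraction helper, not AiDotNet APIs:
// Hypothetical usage with lambda-based feature extractors
var featureLoss = strategy.ComputeFeatureLoss(
    teacherFeatureExtractor: layerName => teacherModel.GetLayerOutput(layerName, input),
    studentFeatureExtractor: layerName => studentModel.GetLayerOutput(layerName, input),
    input: input
);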
If feature dimensions don't match, consider adding a projection layer (simple linear transformation) to the student.
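Such a projection can be as simple as a learned linear map. A minimal sketch, using plain arrays for clarity (W and b would be trained jointly with the student):
// Hypothetical linear projection mapping studentDim features to teacherDim,
// so MSE can be computed against the teacher's features.
// W is studentDim x teacherDim; b has length teacherDim.
double[] Project(double[] studentFeatures, double[,] W, double[] b)
{
    int teacherDim = b.Length;
    var projected = new double[teacherDim];
    for (int j = 0; j < teacherDim; j++)
    {
        double sum = b[j];
        for (int i = 0; i < studentFeatures.Length; i++)
            sum += studentFeatures[i] * W[i, j];
        projected[j] = sum;                   // projected = student * W + b
    }
    return projected;
}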