Interface IIntermediateActivationStrategy<T>
Namespace: AiDotNet.Interfaces
Assembly: AiDotNet.dll
Defines methods for distillation strategies that utilize intermediate layer activations.
public interface IIntermediateActivationStrategy<T>
Type Parameters
T: The numeric type for calculations (e.g., double, float).
Remarks
For Beginners: Some advanced distillation strategies don't just compare final outputs. They also compare what's happening inside the models at intermediate layers. This interface is for those advanced strategies.
Example Strategies Needing Intermediate Activations:
- Feature-Based Distillation (FitNets): Match intermediate layer features between teacher and student
- Attention Transfer: Transfer attention patterns from internal layers
- Neuron Selectivity: Match how individual neurons respond across batches
- Relational Knowledge Distillation: Transfer relationships between layer activations
Why Separate Interface? Not all strategies need intermediate activations. Simple response-based distillation only needs final outputs. This interface is optional - only implement it if your strategy needs access to internal layer outputs.
Usage Pattern:
// Strategy that needs both final outputs AND intermediate activations
public class MyAdvancedStrategy<T> : DistillationStrategyBase<T>, IIntermediateActivationStrategy<T>
{
    // Implement standard loss/gradient for final outputs
    public override T ComputeLoss(Matrix<T> studentBatch, Matrix<T> teacherBatch, Matrix<T> labels) { ... }
    public override Matrix<T> ComputeGradient(...) { ... }

    // Implement the intermediate activation loss and its gradient
    public T ComputeIntermediateLoss(
        IntermediateActivations<T> studentActivations,
        IntermediateActivations<T> teacherActivations) { ... }

    public IntermediateActivations<T> ComputeIntermediateGradient(
        IntermediateActivations<T> studentActivations,
        IntermediateActivations<T> teacherActivations) { ... }
}
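For a more concrete picture, here is a minimal FitNets-style feature-matching sketch. It is illustrative only, not the library's implementation: it assumes IntermediateActivations<T> exposes its per-layer matrices through a name-keyed Layers dictionary and that Matrix<T> has Rows, Columns, and an indexer (this page confirms none of those members), uses .NET generic math (INumber<T>) in place of whatever numeric abstraction the library actually uses, and omits the DistillationStrategyBase<T> members shown in the skeleton above.
using System;
using System.Numerics;

// Minimal FitNets-style sketch (illustrative; not the library's implementation).
// The Layers dictionary and the Matrix<T> members used here are assumptions.
public class FeatureMatchingStrategy<T> : IIntermediateActivationStrategy<T>
    where T : INumber<T>
{
    private readonly T _weight; // strategy weight, baked into loss and gradients

    public FeatureMatchingStrategy(T weight) => _weight = weight;

    public T ComputeIntermediateLoss(
        IntermediateActivations<T> studentIntermediateActivations,
        IntermediateActivations<T> teacherIntermediateActivations)
    {
        T loss = T.Zero;
        foreach (var (layerName, studentLayer) in studentIntermediateActivations.Layers)
        {
            // Compare only layers present in both models.
            if (!teacherIntermediateActivations.Layers.TryGetValue(layerName, out var teacherLayer))
                continue;

            // Per layer: L = ||student - teacher||^2 / (2 * batchSize), so the
            // gradient is (student - teacher) / batchSize (see Mathematical Notes
            // under ComputeIntermediateGradient below).
            T sumSquares = T.Zero;
            for (int row = 0; row < studentLayer.Rows; row++)
                for (int col = 0; col < studentLayer.Columns; col++)
                {
                    T diff = studentLayer[row, col] - teacherLayer[row, col];
                    sumSquares += diff * diff;
                }

            T batchSize = T.CreateChecked(studentLayer.Rows); // batch x features layout
            loss += sumSquares / (T.CreateChecked(2) * batchSize);
        }
        return _weight * loss; // strategy weight applied here
    }

    public IntermediateActivations<T> ComputeIntermediateGradient(
        IntermediateActivations<T> studentIntermediateActivations,
        IntermediateActivations<T> teacherIntermediateActivations)
    {
        // See the matching gradient sketch under ComputeIntermediateGradient below.
        throw new NotImplementedException();
    }
}
A real strategy would typically also derive from DistillationStrategyBase<T> and implement the output-based ComputeLoss and ComputeGradient shown in the skeleton.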
Methods
ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)
Computes gradients of the intermediate activation loss with respect to student activations.
IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
studentIntermediateActivations (IntermediateActivations<T>): Student's intermediate activations for a batch.
teacherIntermediateActivations (IntermediateActivations<T>): Teacher's intermediate activations for a batch.
Returns
IntermediateActivations<T>: Gradients for each intermediate layer (same structure as studentIntermediateActivations).
Remarks
For Beginners: This method computes how much to adjust each neuron's activation in the student's intermediate layers to better match the teacher. These gradients get backpropagated through the student network during training.
How It's Used: After computing intermediate loss, the trainer needs gradients to update the student model. This method provides those gradients layer by layer.
Example Calculation:
// In training loop (backward pass)
var intermediateGradients = advancedStrategy.ComputeIntermediateGradient(
    studentResult.IntermediateActivations,
    teacherResult.IntermediateActivations);

// Backpropagate through student network using these gradients
student.BackpropagateFromIntermediateLayers(intermediateGradients);
Implementation Requirements:
- Return gradients only for layers where loss was computed
- Gradient matrices must match activation dimensions (batch x features)
- Gradients should already be weighted (include the strategy weight in the computation)
- Return an empty IntermediateActivations if there are no gradients (edge case)
Mathematical Notes: For an MSE loss on layer activations, the gradient is ∂L/∂student = (student - teacher) / batchSize. For other losses, derive the gradient analytically and return it in the same per-layer structure.
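As a hedged sketch of that MSE case, continuing the illustrative FeatureMatchingStrategy example above (the writable Layers dictionary, the parameterless IntermediateActivations<T> constructor, the Matrix<T>(rows, columns) constructor, and the INumber<T> arithmetic are all assumptions, not confirmed API):
// Pairs with the loss in the FeatureMatchingStrategy sketch: since that loss
// divides by 2 * batchSize, the factor of 2 from the square cancels here.
public IntermediateActivations<T> ComputeIntermediateGradient(
    IntermediateActivations<T> studentIntermediateActivations,
    IntermediateActivations<T> teacherIntermediateActivations)
{
    var gradients = new IntermediateActivations<T>();
    foreach (var (layerName, studentLayer) in studentIntermediateActivations.Layers)
    {
        // Return gradients only for layers where the loss was computed.
        if (!teacherIntermediateActivations.Layers.TryGetValue(layerName, out var teacherLayer))
            continue;

        // dL/dstudent = weight * (student - teacher) / batchSize,
        // with batchSize = rows in the batch x features layout.
        T batchSize = T.CreateChecked(studentLayer.Rows);
        var gradient = new Matrix<T>(studentLayer.Rows, studentLayer.Columns);
        for (int row = 0; row < studentLayer.Rows; row++)
            for (int col = 0; col < studentLayer.Columns; col++)
                gradient[row, col] =
                    _weight * (studentLayer[row, col] - teacherLayer[row, col]) / batchSize;

        gradients.Layers[layerName] = gradient;
    }
    return gradients; // empty when no layers matched - the documented edge case
}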
ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)
Computes a loss component based on intermediate layer activations for a batch.
T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
studentIntermediateActivations (IntermediateActivations<T>): Student's intermediate activations for a batch.
teacherIntermediateActivations (IntermediateActivations<T>): Teacher's intermediate activations for a batch.
Returns
T: The computed intermediate activation loss component.
Remarks
For Beginners: This method compares what's happening inside the teacher and student models, not just their final outputs. For example:
- Are neurons in layer 3 responding similarly?
- Do attention patterns match in the middle layers?
- Are feature representations aligned?
How It's Used: The trainer collects intermediate activations during forward passes, then calls this method to compute an additional loss component. This loss is combined with the standard output loss.
Example Calculation:
// In training loop
var teacherResult = teacher.Forward(inputBatch, collectIntermediateActivations: true);
var studentResult = student.Forward(inputBatch, collectIntermediateActivations: true);

// Standard loss (final outputs)
T outputLoss = strategy.ComputeLoss(studentResult.FinalOutput, teacherResult.FinalOutput, labels);

// Intermediate loss (internal layers) - only if strategy implements this interface.
// Note: the 0 literal and the + operator below are schematic; with an
// unconstrained generic T, use the library's numeric operations for arithmetic.
T intermediateLoss = 0;
if (strategy is IIntermediateActivationStrategy<T> advancedStrategy)
{
    intermediateLoss = advancedStrategy.ComputeIntermediateLoss(
        studentResult.IntermediateActivations,
        teacherResult.IntermediateActivations);
}

// Total loss
T totalLoss = outputLoss + intermediateLoss;
Implementation Tips:
- Match layers by name (e.g., "conv1", "layer3")
- Handle mismatched architectures gracefully (e.g., the teacher has more layers than the student)
- Consider weighting different layers differently (early layers vs. late layers)
- Normalize activations if needed (different layers have different scales)
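To make those tips concrete, here is one possible shape for the name-matching, weighting, and normalization pieces. Every name in it (LayerMatchingHelpers, LayerMap, LayerWeights, NormalizeRows) is a hypothetical helper invented for this sketch, and the Matrix<T> members remain assumptions:
using System.Collections.Generic;
using System.Numerics;

// Hypothetical helpers illustrating the tips above; none of this is AiDotNet API.
public static class LayerMatchingHelpers
{
    // Map student layer names to teacher layer names so mismatched
    // architectures (e.g., a deeper teacher) can still be compared.
    public static readonly IReadOnlyDictionary<string, string> LayerMap =
        new Dictionary<string, string>
        {
            ["conv1"] = "conv1",   // early layers matched directly
            ["layer3"] = "layer7", // student's layer3 supervised by teacher's layer7
        };

    // Per-layer weights, e.g. emphasizing later layers over early ones.
    public static readonly IReadOnlyDictionary<string, double> LayerWeights =
        new Dictionary<string, double>
        {
            ["conv1"] = 0.5,
            ["layer3"] = 1.0,
        };

    // L2-normalize each row (one sample's activations) so layers with
    // different scales contribute comparably to the loss.
    public static Matrix<T> NormalizeRows<T>(Matrix<T> activations)
        where T : INumber<T>, IRootFunctions<T>
    {
        var result = new Matrix<T>(activations.Rows, activations.Columns);
        for (int row = 0; row < activations.Rows; row++)
        {
            T sumSquares = T.Zero;
            for (int col = 0; col < activations.Columns; col++)
                sumSquares += activations[row, col] * activations[row, col];

            T norm = T.Sqrt(sumSquares);
            for (int col = 0; col < activations.Columns; col++)
                result[row, col] = norm == T.Zero
                    ? activations[row, col]        // guard all-zero rows
                    : activations[row, col] / norm;
        }
        return result;
    }
}
A strategy would normalize both the student and teacher layers with the same helper before computing the difference, so scale differences between layers don't dominate the loss.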