Class NeuronSelectivityDistillationStrategy<T>
- Namespace
- AiDotNet.KnowledgeDistillation.Strategies
- Assembly
- AiDotNet.dll
Neuron selectivity distillation that transfers the activation patterns and selectivity of individual neurons.
public class NeuronSelectivityDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IIntermediateActivationStrategy<T>
Type Parameters
T
The numeric type for calculations (e.g., double, float).
- Inheritance
- DistillationStrategyBase<T> → NeuronSelectivityDistillationStrategy<T>
- Implements
- IDistillationStrategy<T>, IIntermediateActivationStrategy<T>
Remarks
For Production Use: This strategy focuses on matching how individual neurons respond to inputs. Some neurons are highly selective (activate strongly for specific patterns), while others are more general. Transferring this selectivity helps the student learn meaningful feature representations.
Key Concept: Neuron selectivity measures how discriminative each neuron is. A highly selective neuron activates strongly for certain inputs and weakly for others. The distribution of selectivity across neurons is important for model performance.
Implementation: Selectivity is measured using three metrics (see the sketch below):
1. Activation variance: how much the neuron's output varies across samples.
2. Sparsity: what fraction of the time the neuron is active.
3. Peak-to-average ratio: how peaked the activation distribution is.
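These metrics can be pictured with plain arrays. A minimal C# sketch for a single neuron, illustrative only (the strategy computes the metrics internally over the target layer's activations; the sample values and the 1e-6 firing threshold here are assumptions):
using System;
using System.Linq;

double[] a = { 0.1, 0.9, 0.0, 0.8, 0.05 }; // one neuron's activations across a batch
double mean = a.Average();
// 1. Activation variance: how much the output varies across samples.
double variance = a.Select(x => (x - mean) * (x - mean)).Average();
// 2. Sparsity: fraction of samples on which the neuron is active at all.
double sparsity = a.Count(x => x > 1e-6) / (double)a.Length;
// 3. Peak-to-average ratio: high for sharply selective neurons.
double peakToAverage = a.Max() / mean;
Console.WriteLine($"var={variance:F4} sparsity={sparsity:F2} p2a={peakToAverage:F2}");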
Usage Pattern: This strategy implements both standard output-based distillation and intermediate activation-based selectivity matching. Use as follows:
Standard Usage (via IDistillationStrategy):
T outputLoss = strategy.ComputeLoss(studentOutput, teacherOutput, labels);
Matrix<T> outputGrad = strategy.ComputeGradient(studentOutput, teacherOutput, labels);
With Intermediate Activations (via IIntermediateActivationStrategy):
// Collect activations during forward pass
var studentActivations = new IntermediateActivations<T>();
studentActivations.Add("layer3", studentLayer3Output);
var teacherActivations = new IntermediateActivations<T>();
teacherActivations.Add("layer3", teacherLayer3Output);
// Compute combined loss
T outputLoss = strategy.ComputeLoss(studentOutput, teacherOutput, labels);
T selectivityLoss = strategy.ComputeIntermediateLoss(studentActivations, teacherActivations);
T totalLoss = outputLoss + selectivityLoss; // selectivityLoss is already weighted
The selectivityWeight and metric parameters control the intermediate activation loss component.
Constructors
NeuronSelectivityDistillationStrategy(string, double, SelectivityMetric, double, double)
public NeuronSelectivityDistillationStrategy(string targetLayerName = "default", double selectivityWeight = 0.5, SelectivityMetric metric = SelectivityMetric.Variance, double temperature = 3, double alpha = 0.3)
Parameters
targetLayerName string
Name of the intermediate layer whose activations are used for selectivity matching. Defaults to "default".
selectivityWeight double
Weight applied to the selectivity loss component. Defaults to 0.5.
metric SelectivityMetric
Selectivity metric to match: Variance, Sparsity, or PeakToAverage. Defaults to Variance.
temperature double
Softmax temperature applied when computing the soft (teacher) loss. Defaults to 3.
alpha double
Balance between the soft (teacher) and hard (label) loss terms. Defaults to 0.3.
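A construction example grounded in the signature above; the layer name "layer3" is an assumption standing in for whatever layer your models expose:
// Match the selectivity of the neurons in "layer3", using activation variance:
var strategy = new NeuronSelectivityDistillationStrategy<double>(
    targetLayerName: "layer3",
    selectivityWeight: 0.5,              // scale of the selectivity loss component
    metric: SelectivityMetric.Variance,
    temperature: 3.0,                    // softening applied to teacher outputs
    alpha: 0.3);                         // soft-vs-hard loss balance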
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the base distillation loss on final outputs.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput Matrix<T>
Student model outputs for the batch.
teacherBatchOutput Matrix<T>
Teacher model outputs for the batch.
trueLabelsBatch Matrix<T>?
Optional ground-truth labels for the hard-loss term.
Returns
- Matrix<T>
Remarks
This method computes the standard distillation gradient (soft + hard) on final outputs. It does NOT include the selectivity gradient, which requires intermediate activations: selectivity gradients must be computed separately (via ComputeIntermediateGradient) and backpropagated through the network. See class remarks for the usage pattern.
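A sketch of how the two gradient paths fit together in one training step, assuming studentOutput, teacherOutput, labels, and both activation collections were captured during the forward pass:
// Gradient of the output-based distillation loss (soft + hard terms):
Matrix<double> outputGrad = strategy.ComputeGradient(studentOutput, teacherOutput, labels);
// Selectivity gradients for the target layer (already weighted by selectivityWeight):
IntermediateActivations<double> selectivityGrads =
    strategy.ComputeIntermediateGradient(studentActivations, teacherActivations);
// Backpropagate outputGrad from the output layer and selectivityGrads from the
// target intermediate layer; how intermediate gradients are injected depends on
// your network and training-loop API.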
ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)
Computes gradients of intermediate activation loss with respect to student activations.
public IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
studentIntermediateActivations IntermediateActivations<T>
Student's intermediate layer activations.
teacherIntermediateActivations IntermediateActivations<T>
Teacher's intermediate layer activations.
Returns
- IntermediateActivations<T>
Gradients for the target layer (already weighted by selectivityWeight).
Remarks
Computes analytical gradients for the selectivity loss based on the chosen metric. For the Variance metric, it uses the exact derivative ∂var/∂aᵢ = (2/B)(aᵢ − mean), where B is the batch size. For Sparsity and PeakToAverage, it falls back to numerical approximation because those metrics involve non-differentiable operations.
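The variance derivative from the remark can be checked in a few lines of plain C# (illustrative values; the formula is exact because the terms flowing through the mean cancel):
using System;
using System.Linq;

double[] a = { 0.1, 0.9, 0.0, 0.8 }; // one neuron's activations across B samples
int B = a.Length;
double mean = a.Average();
// d(var)/d(a_i) = (2/B) * (a_i - mean)
double[] dVarDa = a.Select(x => (2.0 / B) * (x - mean)).ToArray();
Console.WriteLine(string.Join(", ", dVarDa.Select(g => g.ToString("F4"))));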
ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)
Computes intermediate activation loss by matching neuron selectivity between teacher and student.
public T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)
Parameters
studentIntermediateActivations IntermediateActivations<T>
Student's intermediate layer activations.
teacherIntermediateActivations IntermediateActivations<T>
Teacher's intermediate layer activations.
Returns
- T
The selectivity matching loss (already weighted by selectivityWeight).
Remarks
This implements the IIntermediateActivationStrategy interface to properly integrate selectivity matching into the training loop. The loss is computed from the activations of the layer specified by targetLayerName in the constructor.
If the target layer is not found in the activation dictionaries, returns zero loss.
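For instance, with the construction below, activations registered under any other layer name are simply not seen by the strategy (the layer names and the layer5Output placeholders are assumptions for illustration):
var strategy = new NeuronSelectivityDistillationStrategy<double>(targetLayerName: "layer3");
var student = new IntermediateActivations<double>();
var teacher = new IntermediateActivations<double>();
student.Add("layer5", studentLayer5Output); // registered under a different name...
teacher.Add("layer5", teacherLayer5Output);
double loss = strategy.ComputeIntermediateLoss(student, teacher); // ...so the loss is zero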
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the base distillation loss on final outputs.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
studentBatchOutput Matrix<T>
Student model outputs for the batch.
teacherBatchOutput Matrix<T>
Teacher model outputs for the batch.
trueLabelsBatch Matrix<T>?
Optional ground-truth labels for the hard-loss term.
Returns
- T
Remarks
This method implements the standard distillation loss (soft + hard) on final outputs. It does NOT include the selectivity component, which requires intermediate activations. Compute that component separately (via ComputeIntermediateLoss or ComputeSelectivityLoss) and combine the losses manually, as sketched below. See class remarks for the full usage pattern.
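A sketch of the manual combination, assuming the outputs and per-sample activation arrays were collected during the forward pass:
double outputLoss = strategy.ComputeLoss(studentOutput, teacherOutput, labels);
// Selectivity component from intermediate activations (already weighted):
double selectivityLoss = strategy.ComputeSelectivityLoss(studentLayerActs, teacherLayerActs);
double totalLoss = outputLoss + selectivityLoss;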
ComputeSelectivityLoss(Vector<T>[], Vector<T>[])
Computes neuron selectivity loss by comparing activation patterns across a batch.
public T ComputeSelectivityLoss(Vector<T>[] studentActivations, Vector<T>[] teacherActivations)
Parameters
studentActivations Vector<T>[]
Student neuron activations for a batch [batchSize x numNeurons].
teacherActivations Vector<T>[]
Teacher neuron activations for a batch [batchSize x numNeurons].
Returns
- T
Selectivity matching loss.
Remarks
This should be called with intermediate layer activations, not final outputs. Collect activations for an entire batch, then call this method.
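A sketch of the collection step. GetLayerActivations is a hypothetical helper standing in for however your forward pass exposes intermediate outputs:
int batchSize = 32;
var studentActs = new Vector<double>[batchSize];
var teacherActs = new Vector<double>[batchSize];
for (int i = 0; i < batchSize; i++)
{
    // One activation vector per sample, both taken from the same layer:
    studentActs[i] = GetLayerActivations(studentModel, "layer3", batch[i]); // hypothetical
    teacherActs[i] = GetLayerActivations(teacherModel, "layer3", batch[i]); // hypothetical
}
double loss = strategy.ComputeSelectivityLoss(studentActs, teacherActs);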