Class NeuronSelectivityDistillationStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll

Neuron selectivity distillation that transfers the activation patterns and selectivity of individual neurons.

public class NeuronSelectivityDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>, IIntermediateActivationStrategy<T>

Type Parameters

T

The numeric type for calculations (e.g., double, float).

Inheritance
DistillationStrategyBase<T> → NeuronSelectivityDistillationStrategy<T>

Implements
IDistillationStrategy<T>, IIntermediateActivationStrategy<T>

Remarks

For Production Use: This strategy focuses on matching how individual neurons respond to inputs. Some neurons are highly selective (activate strongly for specific patterns), while others are more general. Transferring this selectivity helps the student learn meaningful feature representations.

Key Concept: Neuron selectivity measures how discriminative each neuron is. A highly selective neuron activates strongly for certain inputs and weakly for others. The distribution of selectivity across neurons is important for model performance.

Implementation: We measure selectivity using one of three metrics (see the sketch below):

1. Activation variance: how much the neuron's output varies across samples.
2. Sparsity: the fraction of samples on which the neuron is active.
3. Peak-to-average ratio: how peaked the activation distribution is.
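For intuition, here is a minimal sketch (illustrative only, not part of the AiDotNet API) of how these three statistics could be computed for a single neuron across a batch, using plain doubles and System.Linq:

double[] a = { 0.1, 0.0, 2.3, 0.0, 1.7 };  // one neuron's outputs across a batch
double mean = a.Average();
double variance = a.Select(x => (x - mean) * (x - mean)).Average();  // 1. activation variance
double sparsity = a.Count(x => x > 0.0) / (double)a.Length;          // 2. fraction of samples active (zero threshold is an assumed convention)
double peakToAverage = a.Max() / Math.Max(mean, 1e-12);              // 3. peak-to-average ratio (guarded against division by zero)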

Usage Pattern: This strategy implements both standard output-based distillation and intermediate activation-based selectivity matching. Use as follows:

Standard Usage (via IDistillationStrategy):

// Output-level distillation loss and gradient (soft + hard components)
T outputLoss = strategy.ComputeLoss(studentOutput, teacherOutput, labels);
Matrix<T> outputGrad = strategy.ComputeGradient(studentOutput, teacherOutput, labels);

With Intermediate Activations (via IIntermediateActivationStrategy):

// Collect activations during forward pass
var studentActivations = new IntermediateActivations<T>();
studentActivations.Add("layer3", studentLayer3Output);

var teacherActivations = new IntermediateActivations<T>();
teacherActivations.Add("layer3", teacherLayer3Output);

// Compute combined loss
T outputLoss = strategy.ComputeLoss(studentOutput, teacherOutput, labels);
T selectivityLoss = strategy.ComputeIntermediateLoss(studentActivations, teacherActivations);
T totalLoss = outputLoss + selectivityLoss; // selectivityLoss is already weighted

The selectivityWeight and metric parameters control the intermediate activation loss component.

Constructors

NeuronSelectivityDistillationStrategy(string, double, SelectivityMetric, double, double)

public NeuronSelectivityDistillationStrategy(string targetLayerName = "default", double selectivityWeight = 0.5, SelectivityMetric metric = SelectivityMetric.Variance, double temperature = 3, double alpha = 0.3)

Parameters

targetLayerName string

Name of the intermediate layer whose activations are used for selectivity matching.

selectivityWeight double

Weight applied to the selectivity (intermediate activation) loss component.

metric SelectivityMetric

The statistic used to measure selectivity (Variance, Sparsity, or PeakToAverage).

temperature double

Softening temperature applied to outputs in the base distillation loss.

alpha double

Balance between the soft (teacher) and hard (true label) components of the base loss.
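A typical construction; the layer name "layer3" is illustrative, and the remaining values shown are the defaults from the signature above:

var strategy = new NeuronSelectivityDistillationStrategy<double>(
    targetLayerName: "layer3",            // intermediate layer to match (illustrative name)
    selectivityWeight: 0.5,               // scales the selectivity loss component
    metric: SelectivityMetric.Variance,   // selectivity statistic to compare
    temperature: 3,                       // softening temperature for the base distillation loss
    alpha: 0.3);                          // soft/hard balance for the base distillation loss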

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the gradient of the base distillation loss on final outputs.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's final outputs for the batch.

teacherBatchOutput Matrix<T>

The teacher model's final outputs for the batch.

trueLabelsBatch Matrix<T>?

Optional true labels for the batch; may be null.

Returns

Matrix<T>

The gradient of the base distillation loss with respect to the student outputs.

Remarks

This method computes the standard distillation gradient (soft + hard components) on final outputs. It does NOT include the selectivity gradient, which requires intermediate activations. Selectivity gradients must be computed separately (via ComputeIntermediateGradient) and backpropagated through the network from the target layer. See the class remarks for the usage pattern, and the sketch below.
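A sketch of how the two gradient paths could be driven in a training step; the Backward* calls on the student model are hypothetical placeholders for whatever backpropagation mechanism your training loop uses:

// Output path: standard soft + hard distillation gradient on final outputs.
Matrix<double> outputGrad = strategy.ComputeGradient(studentOutput, teacherOutput, labels);

// Selectivity path: gradients on the target layer's activations, already weighted.
IntermediateActivations<double> selectivityGrads =
    strategy.ComputeIntermediateGradient(studentActivations, teacherActivations);

// Hypothetical backprop calls, to be adapted to your training loop:
// studentModel.Backward(outputGrad);                            // from the output layer
// studentModel.BackwardFromLayer("layer3", selectivityGrads);   // inject at the target layer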

ComputeIntermediateGradient(IntermediateActivations<T>, IntermediateActivations<T>)

Computes gradients of intermediate activation loss with respect to student activations.

public IntermediateActivations<T> ComputeIntermediateGradient(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)

Parameters

studentIntermediateActivations IntermediateActivations<T>

Student's intermediate layer activations.

teacherIntermediateActivations IntermediateActivations<T>

Teacher's intermediate layer activations.

Returns

IntermediateActivations<T>

Gradients for the target layer (already weighted by selectivityWeight).

Remarks

Computes analytical gradients for the selectivity loss based on the chosen metric. For the Variance metric, it uses ∂var/∂a = (2/B) * (a - mean). For Sparsity and PeakToAverage, it uses a numerical approximation because those metrics involve non-differentiable operations. See the sketch below for the variance case.
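A note on why the Variance formula is exact even though the mean depends on every activation: for var = (1/B) Σᵢ (aᵢ - mean)², the extra term contributed by ∂mean/∂aⱼ vanishes because the deviations (aᵢ - mean) sum to zero. A minimal numeric sketch of this inner gradient, illustrative only, using plain doubles and System.Linq:

double[] a = { 0.1, 0.0, 2.3, 0.0, 1.7 };   // one neuron's activations across a batch of B samples
int B = a.Length;
double mean = a.Average();
// d(var)/d(a_i) = (2/B) * (a_i - mean); the mean's own dependence on a_i cancels.
double[] dVar = a.Select(x => (2.0 / B) * (x - mean)).ToArray();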

ComputeIntermediateLoss(IntermediateActivations<T>, IntermediateActivations<T>)

Computes intermediate activation loss by matching neuron selectivity between teacher and student.

public T ComputeIntermediateLoss(IntermediateActivations<T> studentIntermediateActivations, IntermediateActivations<T> teacherIntermediateActivations)

Parameters

studentIntermediateActivations IntermediateActivations<T>

Student's intermediate layer activations.

teacherIntermediateActivations IntermediateActivations<T>

Teacher's intermediate layer activations.

Returns

T

The selectivity matching loss (already weighted by selectivityWeight).

Remarks

This implements the IIntermediateActivationStrategy interface to properly integrate selectivity matching into the training loop. The loss is computed from the activations of the layer specified by targetLayerName in the constructor.

If the target layer is not found in the activation dictionaries, returns zero loss.

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes the base distillation loss on final outputs.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

studentBatchOutput Matrix<T>

The student model's final outputs for the batch.

teacherBatchOutput Matrix<T>

The teacher model's final outputs for the batch.

trueLabelsBatch Matrix<T>?

Optional true labels for the batch; may be null.

Returns

T

The base distillation loss value.

Remarks

This method computes the standard distillation loss (soft + hard components) on final outputs. It does NOT include the selectivity component, which requires intermediate activations. Use ComputeSelectivityLoss separately and combine the losses manually. See the class remarks for the usage pattern.

ComputeSelectivityLoss(Vector<T>[], Vector<T>[])

Computes neuron selectivity loss by comparing activation patterns across a batch.

public T ComputeSelectivityLoss(Vector<T>[] studentActivations, Vector<T>[] teacherActivations)

Parameters

studentActivations Vector<T>[]

Student neuron activations for a batch: one Vector<T> per sample, each of length numNeurons ([batchSize x numNeurons]).

teacherActivations Vector<T>[]

Teacher neuron activations for a batch: one Vector<T> per sample, each of length numNeurons ([batchSize x numNeurons]).

Returns

T

Selectivity matching loss.

Remarks

This method should be called with intermediate layer activations, not final outputs. Collect activations for an entire batch, then call it once per batch, as sketched below.
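A sketch of the intended call pattern; CaptureStudentActivations and CaptureTeacherActivations are hypothetical helpers standing in for however your code records the target layer's output per sample:

// One Vector<T> per sample; each holds that sample's activations for the target layer.
var studentBatch = new Vector<double>[batchSize];
var teacherBatch = new Vector<double>[batchSize];
for (int i = 0; i < batchSize; i++)
{
    studentBatch[i] = CaptureStudentActivations(inputs[i]);  // hypothetical helper
    teacherBatch[i] = CaptureTeacherActivations(inputs[i]);  // hypothetical helper
}
double selectivityLoss = strategy.ComputeSelectivityLoss(studentBatch, teacherBatch);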