Class RelationalDistillationStrategy<T>
Namespace: AiDotNet.KnowledgeDistillation.Strategies
Assembly: AiDotNet.dll
public class RelationalDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>
Type Parameters
T
Inheritance: DistillationStrategyBase<T> → RelationalDistillationStrategy<T>
Implements: IDistillationStrategy<T>
Constructors
RelationalDistillationStrategy(double, double, double, double, int, RelationalDistanceMetric)
Initializes a new instance of the RelationalDistillationStrategy class.
public RelationalDistillationStrategy(double distanceWeight = 1, double angleWeight = 2, double temperature = 3, double alpha = 0.3, int maxSamplesPerBatch = 32, RelationalDistanceMetric distanceMetric = RelationalDistanceMetric.Euclidean)
Parameters
- distanceWeight (double): Weight for the distance-wise relation loss (default: 1.0).
- angleWeight (double): Weight for the angle-wise relation loss (default: 2.0).
- temperature (double): Temperature for softmax scaling (default: 3.0).
- alpha (double): Balance between the hard and soft losses (default: 0.3).
- maxSamplesPerBatch (int): Maximum number of samples to consider for relations (default: 32).
- distanceMetric (RelationalDistanceMetric): Distance metric to use (default: Euclidean).
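The temperature and alpha parameters follow the usual soft-target distillation scheme. The sketch below illustrates that scheme on plain double arrays: a temperature-scaled softmax, a KL-divergence soft loss, and the alpha-weighted blend with a hard cross-entropy loss. It is not AiDotNet's code. The helper names are hypothetical, the hard loss is assumed to use the student's unscaled (T = 1) softmax, and the soft loss is assumed to be multiplied by T², which this page does not confirm.

```csharp
using System;
using System.Linq;

// Temperature-scaled softmax. Subtracting the max first avoids overflow;
// a larger temperature flattens the resulting distribution.
static double[] Softmax(double[] logits, double temperature)
{
    double max = logits.Max();
    double[] exps = logits.Select(z => Math.Exp((z - max) / temperature)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

// KL(p || q) for two probability vectors of equal length.
static double KlDivergence(double[] p, double[] q)
{
    double kl = 0.0;
    for (int i = 0; i < p.Length; i++)
        if (p[i] > 0) kl += p[i] * Math.Log(p[i] / q[i]);
    return kl;
}

// Hypothetical helper: alpha * hard cross-entropy + (1 - alpha) * T^2 * KL(teacher_T || student_T).
// The T^2 factor and the T = 1 hard loss are assumptions, not confirmed AiDotNet behavior.
static double OutputDistillationLoss(
    double[] studentLogits, double[] teacherLogits, int trueClass, double temperature, double alpha)
{
    double hard = -Math.Log(Softmax(studentLogits, 1.0)[trueClass]);
    double soft = KlDivergence(Softmax(teacherLogits, temperature), Softmax(studentLogits, temperature));
    return alpha * hard + (1.0 - alpha) * temperature * temperature * soft;
}

double[] demoLogits = { 2.0, 1.0, 0.1 };
Console.WriteLine(string.Join(", ", Softmax(demoLogits, 3.0)));
```

With alpha = 0.3, 70% of the blend comes from matching the teacher's softened distribution, which is why a lower alpha leans more heavily on the teacher.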
Remarks
For Beginners: Configure how much to weight different types of relations:
- distanceWeight: how much to preserve pairwise distances
- angleWeight: how much to preserve triplet angles
Higher weights mean stronger enforcement of that relation type.
Example:
```csharp
var strategy = new RelationalDistillationStrategy<double>(
    distanceWeight: 25.0,   // Strong distance preservation
    angleWeight: 50.0,      // Very strong angle preservation (recommended higher)
    temperature: 3.0,
    alpha: 0.3,
    maxSamplesPerBatch: 16  // Limit for efficiency
);
```
Weight Selection Guidelines:
- **Distance-wise**: 10-50 (Park et al. used 25)
- **Angle-wise**: 20-100 (Park et al. used 50, typically 2× the distance weight)
- **Max samples**: 16-32 for efficiency (the full batch is O(n³) for angles)
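To see why the guideline caps the sample count, the sketch below counts the relations each term touches: n(n-1)/2 unordered pairs for the distance-wise term and n(n-1)(n-2) ordered triplets of distinct samples for the angle-wise term. Counting ordered, distinct-index triplets is an assumption about how triplets are enumerated; an implementation may exploit symmetry or include degenerate triplets, but the growth stays cubic either way.

```csharp
using System;

// Unordered sample pairs used by the distance-wise term: n(n-1)/2.
static long PairCount(long n) => n * (n - 1) / 2;

// Ordered triplets (i, j, k) with distinct indices used by the angle-wise term: n(n-1)(n-2).
static long TripletCount(long n) => n * (n - 1) * (n - 2);

foreach (long size in new long[] { 16, 32, 128 })
    Console.WriteLine($"n={size}: pairs={PairCount(size)}, triplets={TripletCount(size)}");
```

At n = 32 that is 496 pairs versus 29,760 triplets; dropping to n = 16 cuts the triplets to 3,360.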
Balancing with Output Loss: The final loss combines:
1. Standard output distillation (α × hard + (1-α) × soft)
2. Distance-wise relational loss × distanceWeight
3. Angle-wise relational loss × angleWeight
Tune the weights so that each component contributes meaningfully to the total.
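Written as a single expression, the three components listed above combine roughly as follows. The symbols are shorthand introduced here: λ_D stands for distanceWeight and λ_A for angleWeight. Whether the soft term is additionally rescaled (for example by T²) is not stated on this page.

```latex
\mathcal{L} = \alpha\,\mathcal{L}_{\text{hard}}
            + (1-\alpha)\,\mathcal{L}_{\text{soft}}
            + \lambda_D\,\mathcal{L}_{\text{dist}}
            + \lambda_A\,\mathcal{L}_{\text{angle}}
```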
Methods
ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the gradient of the combined output loss and relational loss.
public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
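- studentBatchOutput (Matrix<T>): Student outputs for the batch.
- teacherBatchOutput (Matrix<T>): Teacher outputs for the batch.
- trueLabelsBatch (Matrix<T>?): Optional ground-truth labels for the batch (default: null).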
Returns
- Matrix<T>
Remarks
The gradient includes both the standard distillation gradient and the relational gradient computed from pairwise distances and angular relationships in the accumulated batch.
ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)
Computes the combined output loss and relational loss.
public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)
Parameters
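- studentBatchOutput (Matrix<T>): Student outputs for the batch.
- teacherBatchOutput (Matrix<T>): Teacher outputs for the batch.
- trueLabelsBatch (Matrix<T>?): Optional ground-truth labels for the batch (default: null).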
Returns
- T
Remarks
This method accumulates student/teacher outputs and computes relational loss when a batch is complete. The relational loss is then amortized across subsequent samples.
CRITICAL: Call Reset() at epoch boundaries to prevent buffer leakage.
ComputeRelationalLoss(Vector<T>[], Vector<T>[])
Computes relational knowledge distillation loss for a batch of samples.
public T ComputeRelationalLoss(Vector<T>[] studentEmbeddings, Vector<T>[] teacherEmbeddings)
Parameters
- studentEmbeddings (Vector<T>[]): The student's output embeddings/features for the batch.
- teacherEmbeddings (Vector<T>[]): The teacher's output embeddings/features for the batch.
Returns
- T
Combined relational loss (distance + angle).
Remarks
For Beginners: This computes how well the student preserves the teacher's relational structure. Pass in the embeddings/features (not final classifications).
The loss has two components:
1. **Distance-wise**: Do pairs of samples have similar distances?
2. **Angle-wise**: Do triplets of samples have similar angular relationships?
Example usage:
```csharp
// Get embeddings from the penultimate layer.
// GetEmbedding stands in for however your student and teacher models expose features.
var studentEmbeds = new Vector<double>[] { student.GetEmbedding(x1), student.GetEmbedding(x2), /* ... */ };
var teacherEmbeds = new Vector<double>[] { teacher.GetEmbedding(x1), teacher.GetEmbedding(x2), /* ... */ };
var relationalLoss = strategy.ComputeRelationalLoss(studentEmbeds, teacherEmbeds);
```
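The two components described above can be sketched as follows, using the formulation from Park et al.: Euclidean distances divided by their batch mean, the cosine of the angle at the middle sample of each triplet, and a Huber (smooth L1, delta = 1) penalty averaged over pairs or triplets. This is a minimal, self-contained illustration on plain double[][] rather than Vector<T>. It covers only the Euclidean metric, the helper names are hypothetical, and AiDotNet's implementation may differ in normalization, reduction, or triplet selection.

```csharp
using System;

// Huber (smooth L1) penalty with delta = 1.
static double Huber(double x, double y)
{
    double d = Math.Abs(x - y);
    return d <= 1.0 ? 0.5 * d * d : d - 0.5;
}

static double Euclidean(double[] a, double[] b)
{
    double s = 0.0;
    for (int i = 0; i < a.Length; i++) { double d = a[i] - b[i]; s += d * d; }
    return Math.Sqrt(s);
}

// Pairwise distances divided by their mean over all i < j pairs, making the term scale-invariant.
static double[,] NormalizedDistances(double[][] e)
{
    int n = e.Length;
    var dist = new double[n, n];
    double sum = 0.0;
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++)
        {
            dist[i, j] = dist[j, i] = Euclidean(e[i], e[j]);
            sum += dist[i, j];
        }
    double mean = sum / (n * (n - 1) / 2.0);
    if (mean > 0.0)
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                dist[i, j] /= mean;
    return dist;
}

// Cosine of the angle at vertex j formed by samples i and k.
static double CosAngle(double[] ei, double[] ej, double[] ek)
{
    double dot = 0.0, ni = 0.0, nk = 0.0;
    for (int d = 0; d < ej.Length; d++)
    {
        double u = ei[d] - ej[d], v = ek[d] - ej[d];
        dot += u * v; ni += u * u; nk += v * v;
    }
    return (ni == 0.0 || nk == 0.0) ? 0.0 : dot / Math.Sqrt(ni * nk);
}

// Distance-wise term: mean Huber gap between normalized student and teacher distances.
static double DistanceLoss(double[][] s, double[][] t)
{
    var ds = NormalizedDistances(s);
    var dt = NormalizedDistances(t);
    int n = s.Length, count = 0;
    double loss = 0.0;
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++) { loss += Huber(ds[i, j], dt[i, j]); count++; }
    return count == 0 ? 0.0 : loss / count;
}

// Angle-wise term: mean Huber gap between student and teacher cosines over distinct triplets.
static double AngleLoss(double[][] s, double[][] t)
{
    int n = s.Length, count = 0;
    double loss = 0.0;
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < n; k++)
            {
                if (i == j || j == k || i == k) continue;
                loss += Huber(CosAngle(s[i], s[j], s[k]), CosAngle(t[i], t[j], t[k]));
                count++;
            }
    return count == 0 ? 0.0 : loss / count;
}

// Weighted sum mirroring the distanceWeight and angleWeight constructor parameters.
static double RelationalLoss(double[][] s, double[][] t, double distanceWeight, double angleWeight)
    => distanceWeight * DistanceLoss(s, t) + angleWeight * AngleLoss(s, t);

double[][] demoTeacher = { new[] { 0.0, 0.0 }, new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 } };
double[][] demoStudent = { new[] { 0.0, 0.0 }, new[] { 2.0, 0.1 }, new[] { 0.1, 1.5 } };
Console.WriteLine(RelationalLoss(demoStudent, demoTeacher, distanceWeight: 25.0, angleWeight: 50.0));
```

Because distances are normalized by their mean and angles use cosines, a student whose embeddings are a uniformly scaled copy of the teacher's incurs zero relational loss; only the structure between samples is compared.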
Reset()
Resets the strategy's internal state.
public void Reset()
Remarks
Note: With Matrix<T> batch processing, this strategy no longer maintains state between calls, so Reset() is a no-op. It is kept for compatibility with the trainer's OnEpochEnd(), which calls Reset() on all strategies.