
Class RelationalDistillationStrategy<T>

Namespace
AiDotNet.KnowledgeDistillation.Strategies
Assembly
AiDotNet.dll
public class RelationalDistillationStrategy<T> : DistillationStrategyBase<T>, IDistillationStrategy<T>

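Type Parameters

T

The numeric type used for outputs, embeddings, and losses (for example, double).

Inheritance

DistillationStrategyBase<T>
RelationalDistillationStrategy<T>

Implements

IDistillationStrategy<T>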

Constructors

RelationalDistillationStrategy(double, double, double, double, int, RelationalDistanceMetric)

Initializes a new instance of the RelationalDistillationStrategy class.

public RelationalDistillationStrategy(double distanceWeight = 1, double angleWeight = 2, double temperature = 3, double alpha = 0.3, int maxSamplesPerBatch = 32, RelationalDistanceMetric distanceMetric = RelationalDistanceMetric.Euclidean)

Parameters

distanceWeight double

Weight for distance-wise relation loss (default: 1.0).

angleWeight double

Weight for angle-wise relation loss (default: 2.0).

temperature double

Temperature for softmax scaling (default: 3.0).

alpha double

Weight on the hard (true-label) loss; the soft (teacher) loss is weighted by 1 - alpha (default: 0.3).

maxSamplesPerBatch int

Maximum number of samples per batch used to compute the relational losses (default: 32).

distanceMetric RelationalDistanceMetric

Distance metric to use (default: Euclidean).
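The temperature parameter is easier to understand with a small example. Below is a minimal standalone sketch of temperature-scaled softmax, softmax(z / T), which is the standard formulation from Hinton et al.'s knowledge distillation. The class name TemperatureSoftmax is hypothetical and not part of AiDotNet, and the library's internal implementation may differ.

```csharp
using System;
using System.Linq;

// Hypothetical helper for illustration only; not an AiDotNet type.
public static class TemperatureSoftmax
{
    // Softmax over logits divided by a temperature T (T > 0).
    // Larger T flattens the distribution; T = 1 is the ordinary softmax.
    public static double[] Compute(double[] logits, double temperature)
    {
        if (temperature <= 0) throw new ArgumentOutOfRangeException(nameof(temperature));
        double[] scaled = logits.Select(z => z / temperature).ToArray();
        double max = scaled.Max();                          // subtract the max for numerical stability
        double[] exps = scaled.Select(z => Math.Exp(z - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }
}
```

With logits {2, 1, 0.1}, the top class receives about 0.66 probability at T = 1 and only about 0.45 at T = 3. The higher temperature therefore exposes more of the teacher's relative preferences among the other classes.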

Remarks

For Beginners: Configure how much weight each type of relation receives:
- distanceWeight: how strongly to preserve pairwise distances
- angleWeight: how strongly to preserve triplet angles

Higher weights enforce that relation type more strongly.

Example:

var strategy = new RelationalDistillationStrategy<double>(
    distanceWeight: 25.0,  // Strong distance preservation
    angleWeight: 50.0,     // Very strong angle preservation (typically ~2× distanceWeight)
    temperature: 3.0,
    alpha: 0.3,
    maxSamplesPerBatch: 16  // Limit for efficiency
);

Weight Selection Guidelines:
- **Distance-wise**: 10-50 (Park et al. used 25)
- **Angle-wise**: 20-100 (Park et al. used 50, typically 2× the distance weight)
- **Max samples**: 16-32 for efficiency (using the full batch costs O(n³) for angles)
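The O(n³) remark is easier to appreciate with numbers. The sketch below counts the relations that an RKD-style loss visits for n samples. It assumes that every ordered pair and every ordered triplet of distinct samples is used. The exact enumeration AiDotNet performs is not documented on this page, and RelationCost is a hypothetical helper.

```csharp
// Hypothetical helper for illustration only; not an AiDotNet type.
public static class RelationCost
{
    // Ordered pairs (i, j) with i != j, as visited by a distance-wise loss: n(n - 1).
    public static long PairCount(int n) => (long)n * (n - 1);

    // Ordered triplets (i, j, k) of distinct samples, as visited by an angle-wise loss: n(n - 1)(n - 2).
    public static long TripletCount(int n) => (long)n * (n - 1) * (n - 2);
}
```

Counts under this assumption:

- n = 16: 240 pairs and 3,360 triplets.
- n = 32: 992 pairs and 29,760 triplets.
- n = 128: over two million triplets.

Doubling maxSamplesPerBatch therefore multiplies the angle-wise work by roughly eight.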

Balancing with Output Loss: The final loss combines:
1. Standard output distillation (α × hard + (1-α) × soft)
2. Distance-wise relational loss × distanceWeight
3. Angle-wise relational loss × angleWeight

Tune the weights so that each component contributes meaningfully to the total.
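Written as a single expression, the combination described above is roughly the following. Here λ_D and λ_A are shorthand for distanceWeight and angleWeight, and the sketch assumes the parts are simply summed as listed.

```latex
\mathcal{L} = \alpha\,\mathcal{L}_{\text{hard}} + (1 - \alpha)\,\mathcal{L}_{\text{soft}}
            + \lambda_D\,\mathcal{L}_D + \lambda_A\,\mathcal{L}_A
```

As an illustration, take α = 0.3, λ_D = 25, λ_A = 50, and made-up component values L_hard = 1.2, L_soft = 0.8, L_D = 0.05, L_A = 0.02. The total is 0.36 + 0.56 + 1.25 + 1.00 = 3.17. Both relational terms then contribute on the same order as the output terms, which is the balance the guideline aims for.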

Methods

ComputeGradient(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes gradient of combined output loss and relational loss.

public override Matrix<T> ComputeGradient(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

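studentBatchOutput Matrix<T>

The student model's outputs for the batch.

teacherBatchOutput Matrix<T>

The teacher model's outputs for the same batch.

trueLabelsBatch Matrix<T>

Optional ground-truth labels for the batch (default: null).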

Returns

Matrix<T>

Remarks

The gradient includes both the standard distillation gradient and the relational gradient computed from pairwise distances and angular relationships in the accumulated batch.

ComputeLoss(Matrix<T>, Matrix<T>, Matrix<T>?)

Computes combined output loss and relational loss.

public override T ComputeLoss(Matrix<T> studentBatchOutput, Matrix<T> teacherBatchOutput, Matrix<T>? trueLabelsBatch = null)

Parameters

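studentBatchOutput Matrix<T>

The student model's outputs for the batch.

teacherBatchOutput Matrix<T>

The teacher model's outputs for the same batch.

trueLabelsBatch Matrix<T>

Optional ground-truth labels for the batch (default: null).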

Returns

T

Remarks

This method accumulates student/teacher outputs and computes relational loss when a batch is complete. The relational loss is then amortized across subsequent samples.

CRITICAL: Call Reset() at epoch boundaries to prevent buffer leakage.

ComputeRelationalLoss(Vector<T>[], Vector<T>[])

Computes relational knowledge distillation loss for a batch of samples.

public T ComputeRelationalLoss(Vector<T>[] studentEmbeddings, Vector<T>[] teacherEmbeddings)

Parameters

studentEmbeddings Vector<T>[]

Student's output embeddings/features for batch.

teacherEmbeddings Vector<T>[]

Teacher's output embeddings/features for batch.

Returns

T

Combined relational loss (distance + angle).

Remarks

For Beginners: This computes how well the student preserves the teacher's relational structure. Pass in the embeddings/features (not final classifications).

The loss has two components:
1. **Distance-wise**: Do pairs of samples lie at similar (normalized) distances in the student and teacher embedding spaces?
2. **Angle-wise**: Do triplets of samples form similar angles in the two spaces?
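For reference, Park et al.'s Relational Knowledge Distillation (the work cited in the weight guidelines) defines the two terms as follows:

- t_i and s_i are the teacher and student embeddings of sample i.
- 𝒫 is the set of pairs of distinct samples, and 𝒯 is the set of triplets of distinct samples.
- ℓ_δ is the Huber loss.
- ψ_D(s_i, s_j) normalizes by the student's own mean distance μ_s, in the same way that μ_t is used for the teacher.

This is the published formulation, assuming Euclidean distance. AiDotNet's implementation may differ, for example through its other RelationalDistanceMetric options or through averaging instead of summing.

```latex
\begin{aligned}
\psi_D(t_i, t_j) &= \frac{\lVert t_i - t_j \rVert_2}{\mu_t},
  \qquad \mu_t = \frac{1}{|\mathcal{P}|} \sum_{(i,j) \in \mathcal{P}} \lVert t_i - t_j \rVert_2 \\
\mathcal{L}_D &= \sum_{(i,j) \in \mathcal{P}} \ell_\delta\big(\psi_D(t_i, t_j),\ \psi_D(s_i, s_j)\big) \\
\psi_A(t_i, t_j, t_k) &= \cos \angle t_i t_j t_k
  = \Big\langle \frac{t_i - t_j}{\lVert t_i - t_j \rVert_2},\ \frac{t_k - t_j}{\lVert t_k - t_j \rVert_2} \Big\rangle \\
\mathcal{L}_A &= \sum_{(i,j,k) \in \mathcal{T}} \ell_\delta\big(\psi_A(t_i, t_j, t_k),\ \psi_A(s_i, s_j, s_k)\big) \\
\ell_\delta(x, y) &= \begin{cases} \tfrac{1}{2}(x - y)^2 & \text{if } |x - y| \le 1 \\ |x - y| - \tfrac{1}{2} & \text{otherwise} \end{cases}
\end{aligned}
```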

Example usage:

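// Get embeddings from the penultimate layer for each input in the batch.
// `inputs`, `student.GetEmbedding`, and `teacher.GetEmbedding` are placeholders for
// your own batch and feature-extraction code (requires `using System.Linq;`).
var studentEmbeds = inputs.Select(x => student.GetEmbedding(x)).ToArray();
var teacherEmbeds = inputs.Select(x => teacher.GetEmbedding(x)).ToArray();

// Both arrays must use the same order so that index i refers to the same input.
var relationalLoss = strategy.ComputeRelationalLoss(studentEmbeds, teacherEmbeds);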
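A self-contained sketch of the same kind of computation on plain double[][] arrays may clarify what the method measures. The sketch follows Park et al.'s Relational Knowledge Distillation formulation:

- pairwise Euclidean distances, divided by their mean;
- cosines of the angle at the middle sample of each triplet;
- a Huber loss with δ = 1.

It averages the Huber terms rather than summing them. The class RkdSketch and its methods are hypothetical illustrations, not AiDotNet APIs (the real method takes Vector<T>[]), and the library's normalization details may differ.

```csharp
using System;

// Hypothetical RKD sketch for illustration only; not an AiDotNet type.
public static class RkdSketch
{
    // Huber (smooth L1) loss with delta = 1.
    static double Huber(double x, double y)
    {
        double d = Math.Abs(x - y);
        return d <= 1.0 ? 0.5 * d * d : d - 0.5;
    }

    static double Distance(double[] a, double[] b)
    {
        double s = 0;
        for (int i = 0; i < a.Length; i++) { double d = a[i] - b[i]; s += d * d; }
        return Math.Sqrt(s);
    }

    // Pairwise distances divided by their mean over all pairs i != j (makes the loss scale-invariant).
    static double[,] NormalizedDistances(double[][] e)
    {
        int n = e.Length;
        var d = new double[n, n];
        double sum = 0; int count = 0;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                if (i != j) { d[i, j] = Distance(e[i], e[j]); sum += d[i, j]; count++; }
        double mean = sum / count;
        if (mean > 0)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    d[i, j] /= mean;
        return d;
    }

    // Distance-wise loss: mean Huber difference between normalized student and teacher distances.
    public static double DistanceLoss(double[][] student, double[][] teacher)
    {
        int n = student.Length;
        var ds = NormalizedDistances(student);
        var dt = NormalizedDistances(teacher);
        double loss = 0; int count = 0;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                if (i != j) { loss += Huber(ds[i, j], dt[i, j]); count++; }
        return count > 0 ? loss / count : 0.0;
    }

    // Cosine of the angle at vertex j formed by samples i, j, k.
    static double Cosine(double[] ei, double[] ej, double[] ek)
    {
        double dot = 0, ni = 0, nk = 0;
        for (int m = 0; m < ej.Length; m++)
        {
            double u = ei[m] - ej[m], v = ek[m] - ej[m];
            dot += u * v; ni += u * u; nk += v * v;
        }
        double denom = Math.Sqrt(ni) * Math.Sqrt(nk);
        return denom > 0 ? dot / denom : 0.0;
    }

    // Angle-wise loss: mean Huber difference between student and teacher triplet cosines (O(n³)).
    public static double AngleLoss(double[][] student, double[][] teacher)
    {
        int n = student.Length;
        double loss = 0; int count = 0;
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                {
                    if (i == j || j == k || i == k) continue;
                    loss += Huber(Cosine(student[i], student[j], student[k]),
                                  Cosine(teacher[i], teacher[j], teacher[k]));
                    count++;
                }
        return count > 0 ? loss / count : 0.0;
    }

    // Weighted combination of the two relational terms.
    public static double RelationalLoss(double[][] student, double[][] teacher,
                                        double distanceWeight, double angleWeight)
        => distanceWeight * DistanceLoss(student, teacher) + angleWeight * AngleLoss(student, teacher);
}
```

Distances are divided by their mean, and angles do not change under scaling. As a result, a student whose embeddings are a uniformly scaled copy of the teacher's gets zero loss, and only differences in relative structure are penalized. The triple loop in AngleLoss is the O(n³) cost that maxSamplesPerBatch is meant to bound.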

Reset()

Resets the strategy's internal state.

public void Reset()

Remarks

Note: With Matrix<T> batch processing, this strategy no longer maintains state between calls, so Reset() is a no-op. It's kept for compatibility with the trainer's OnEpochEnd() which calls Reset() on all strategies.