Class SSLMetrics<T>
- Namespace
- AiDotNet.SelfSupervisedLearning.Evaluation
- Assembly
- AiDotNet.dll
Metrics for monitoring and evaluating self-supervised learning.
public class SSLMetrics<T>
Type Parameters
T: The numeric type used for computations.
- Inheritance
- object
SSLMetrics<T>
Remarks
For Beginners: These metrics help track the quality of SSL training and detect potential issues like representation collapse. Monitoring these during training helps ensure the model is learning useful representations.
Key metrics:
- Representation collapse: All embeddings become identical (very bad)
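- Feature std: Standard deviation of features (should stay clearly above zero; values near zero signal collapse)
- Alignment: Closeness of positive pairs (pairs should be close; ComputeAlignment returns a distance-based loss, so lower is better)
- Uniformity: Distribution of features on the hypersphere (should be spread evenly; ComputeUniformity returns a loss, so lower is more uniform)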
Methods
ComputeAlignment(Tensor<T>, Tensor<T>)
Computes alignment loss between positive pairs.
public T ComputeAlignment(Tensor<T> z1, Tensor<T> z2)
Parameters
z1 (Tensor<T>): First view representations [batch_size, dim].
z2 (Tensor<T>): Second view representations [batch_size, dim].
Returns
- T
Average squared distance between positive pairs.
Remarks
Lower alignment loss means positive pairs are closer together in representation space, which is desirable.
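The "average squared distance between positive pairs" can be illustrated with a minimal sketch on plain arrays. This is not the library's implementation: the class `AlignmentSketch`, the method `Alignment`, and the use of `double[][]` in place of `Tensor<T>` are assumptions made for illustration, and the library may normalize embeddings first.

```csharp
using System;

public static class AlignmentSketch
{
    // Hypothetical helper, not part of AiDotNet.
    // Returns the mean squared Euclidean distance between corresponding
    // rows of z1 and z2 (rows are samples, columns are dimensions).
    public static double Alignment(double[][] z1, double[][] z2)
    {
        if (z1.Length == 0 || z1.Length != z2.Length)
            throw new ArgumentException("z1 and z2 must be non-empty with the same batch size.");

        double total = 0.0;
        for (int i = 0; i < z1.Length; i++)
        {
            if (z1[i].Length != z2[i].Length)
                throw new ArgumentException("Paired rows must have the same dimension.");

            for (int j = 0; j < z1[i].Length; j++)
            {
                double diff = z1[i][j] - z2[i][j];
                total += diff * diff; // accumulate squared distance for pair i
            }
        }
        return total / z1.Length; // average over the batch
    }
}
```

Identical views give a value of 0, and the value grows as the two views of each sample drift apart.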
ComputeCosineSimilarity(Tensor<T>, Tensor<T>)
Computes cosine similarity between corresponding pairs.
public T ComputeCosineSimilarity(Tensor<T> z1, Tensor<T> z2)
Parameters
z1 (Tensor<T>)
z2 (Tensor<T>)
Returns
- T
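Since the method returns a single value, one plausible reading is that the per-pair cosine similarities are averaged over the batch. The sketch below follows that reading on plain arrays; the averaging, the zero-vector handling, and the names `CosineSketch` and `MeanCosineSimilarity` are assumptions, not the documented behavior.

```csharp
using System;

public static class CosineSketch
{
    // Hypothetical helper, not part of AiDotNet.
    // Averages cos(z1[i], z2[i]) over all rows i.
    public static double MeanCosineSimilarity(double[][] z1, double[][] z2)
    {
        if (z1.Length == 0 || z1.Length != z2.Length)
            throw new ArgumentException("z1 and z2 must be non-empty with the same batch size.");

        double total = 0.0;
        for (int i = 0; i < z1.Length; i++)
        {
            if (z1[i].Length != z2[i].Length)
                throw new ArgumentException("Paired rows must have the same dimension.");

            double dot = 0.0, sq1 = 0.0, sq2 = 0.0;
            for (int j = 0; j < z1[i].Length; j++)
            {
                dot += z1[i][j] * z2[i][j];
                sq1 += z1[i][j] * z1[i][j];
                sq2 += z2[i][j] * z2[i][j];
            }

            double denom = Math.Sqrt(sq1) * Math.Sqrt(sq2);
            // Assumption: a zero vector contributes similarity 0 instead of dividing by zero.
            total += denom > 0.0 ? dot / denom : 0.0;
        }
        return total / z1.Length;
    }
}
```

Values near 1 mean the two views of each sample point in nearly the same direction; values near 0 mean they are close to orthogonal.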
ComputeEffectiveRank(Tensor<T>)
Computes the effective rank of the representation matrix.
public T ComputeEffectiveRank(Tensor<T> representations)
Parameters
representations (Tensor<T>): Representations [batch_size, dim].
Returns
- T
Effective rank (normalized entropy of singular values).
Remarks
Effective rank measures the "dimensionality" of the learned representations. Collapsed representations have low effective rank. Good representations should use many dimensions effectively.
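For reference, a common entropy-based definition of effective rank is written out below. The symbols are assumptions introduced here: \sigma_1, ..., \sigma_r are the singular values of the representation matrix Z and r is their count. The page does not say whether the method returns the exponentiated entropy, the entropy divided by \log r, or another variant, so treat this as one plausible formulation rather than the library's exact formula.

```latex
p_i = \frac{\sigma_i}{\sum_{k=1}^{r} \sigma_k}, \qquad
H(p) = -\sum_{i=1}^{r} p_i \log p_i, \qquad
\operatorname{erank}(Z) = \exp\bigl(H(p)\bigr), \qquad
\bar{H}(p) = \frac{H(p)}{\log r} \in [0, 1]
```

Under this definition, equal singular values give \operatorname{erank}(Z) = r (and \bar{H} = 1), while a single nonzero singular value gives \operatorname{erank}(Z) = 1 (and \bar{H} = 0), which matches the collapsed case described above.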
ComputeFullReport(Tensor<T>, Tensor<T>)
Computes a full set of SSL metrics.
public SSLMetricReport<T> ComputeFullReport(Tensor<T> z1, Tensor<T> z2)
Parameters
z1 (Tensor<T>)
z2 (Tensor<T>)
Returns
- SSLMetricReport<T>
Report containing the computed SSL metrics.
ComputeRepresentationStd(Tensor<T>)
Computes the standard deviation of representations (collapse detection).
public T ComputeRepresentationStd(Tensor<T> representations)
Parameters
representations (Tensor<T>): Representations [batch_size, dim].
Returns
- T
Standard deviation per dimension, averaged.
Remarks
A low standard deviation indicates potential collapse: all representations are becoming similar. Healthy representations keep a clearly nonzero variance across the batch in each dimension.
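A minimal sketch of "standard deviation per dimension, averaged" on plain arrays is shown below. The names `FeatureStdSketch` and `MeanFeatureStd` are hypothetical, and the use of the population standard deviation (dividing by the batch size rather than batch size minus one) is an assumption; the library may differ.

```csharp
using System;

public static class FeatureStdSketch
{
    // Hypothetical helper, not part of AiDotNet.
    // Computes the standard deviation of each column across the batch,
    // then averages those values over all columns.
    public static double MeanFeatureStd(double[][] reps)
    {
        int n = reps.Length;
        if (n == 0 || reps[0].Length == 0)
            throw new ArgumentException("Representations must be non-empty.");

        int dim = reps[0].Length;
        double stdSum = 0.0;
        for (int j = 0; j < dim; j++)
        {
            double mean = 0.0;
            for (int i = 0; i < n; i++) mean += reps[i][j];
            mean /= n;

            double variance = 0.0;
            for (int i = 0; i < n; i++)
            {
                double d = reps[i][j] - mean;
                variance += d * d;
            }
            variance /= n; // population variance (assumption)

            stdSum += Math.Sqrt(variance);
        }
        return stdSum / dim;
    }
}
```

A batch in which every row is identical yields 0, the fully collapsed case.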
ComputeUniformity(Tensor<T>, double)
Computes uniformity loss (how uniformly distributed embeddings are).
public T ComputeUniformity(Tensor<T> representations, double t = 2)
Parameters
representations (Tensor<T>): Representations [batch_size, dim].
t (double): Temperature parameter (default: 2).
Returns
- T
Uniformity metric (lower is more uniform).
Remarks
Measures how uniformly representations are distributed on the hypersphere. More uniform distributions indicate better representations that capture diverse information.
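One widely used uniformity loss is the logarithm of the average Gaussian potential, exp(-t * squared distance), over all pairs of L2-normalized embeddings. The sketch below implements that formulation on plain arrays as an illustration. Whether `ComputeUniformity` uses exactly this formula, normalizes its input, or averages over the same set of pairs is not stated on this page, and the names `UniformitySketch` and `Uniformity` are hypothetical.

```csharp
using System;

public static class UniformitySketch
{
    // Hypothetical helper, not part of AiDotNet.
    // log( mean over pairs a < b of exp(-t * ||u_a - u_b||^2) ),
    // where u_a is row a scaled to unit length.
    // Assumes every row has the same number of columns.
    public static double Uniformity(double[][] reps, double t = 2.0)
    {
        int n = reps.Length;
        if (n < 2)
            throw new ArgumentException("At least two representations are required.");

        // Project every row onto the unit hypersphere.
        var unit = new double[n][];
        for (int i = 0; i < n; i++)
        {
            double norm = 0.0;
            foreach (double v in reps[i]) norm += v * v;
            norm = Math.Sqrt(norm);
            if (norm == 0.0)
                throw new ArgumentException("A zero vector cannot be normalized.");

            unit[i] = new double[reps[i].Length];
            for (int j = 0; j < reps[i].Length; j++) unit[i][j] = reps[i][j] / norm;
        }

        double sum = 0.0;
        int pairs = 0;
        for (int a = 0; a < n; a++)
        {
            for (int b = a + 1; b < n; b++)
            {
                double sq = 0.0;
                for (int j = 0; j < unit[a].Length; j++)
                {
                    double d = unit[a][j] - unit[b][j];
                    sq += d * d;
                }
                sum += Math.Exp(-t * sq); // Gaussian potential of the pair
                pairs++;
            }
        }
        return Math.Log(sum / pairs);
    }
}
```

With this formulation, fully collapsed embeddings give 0 (the maximum), and the value becomes more negative as points spread apart on the sphere.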
DetectCollapse(Tensor<T>, double)
Detects if representations are collapsing.
public bool DetectCollapse(Tensor<T> representations, double threshold = 0.01)
Parameters
representations (Tensor<T>): Representations to check.
threshold (double): Threshold for collapse detection (default: 0.01).
Returns
- bool
True if collapse is detected.