Class ImportanceSampler<T>
A sampler that implements importance sampling for variance reduction.
public class ImportanceSampler<T> : DataSamplerBase, IDataSampler
Type Parameters
T: The numeric type for importance weights.
- Inheritance
- DataSamplerBase
- ImportanceSampler<T>
- Implements
- IDataSampler
Remarks
ImportanceSampler samples data points based on their importance, typically computed from gradient norms, loss values, or uncertainty estimates. This can accelerate training by focusing on samples that contribute most to learning.
For Beginners: Not all training samples are equally useful. Importance sampling focuses training on the most informative samples:
- High gradient norm = Sample provides strong learning signal
- High loss = Model currently fits this sample poorly and needs more training on it
- High uncertainty = Model is unsure about this sample and needs to see it more often
In favorable cases this can reduce training time by as much as 2-3x compared to uniform sampling; the actual gain depends on the model and the dataset.
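Example:
// Requires: using System.Linq; (for Zip)
var sampler = new ImportanceSampler<float>(datasetSize: 1000);
// After each batch, update importance based on gradient norms.
// `batch` holds the sample indices used in the batch and `gradientNorms`
// holds the matching per-sample gradient norms, in the same order.
foreach (var (idx, gradNorm) in batch.Zip(gradientNorms))
{
    sampler.UpdateImportance(idx, gradNorm);
}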
Constructors
ImportanceSampler(int, double, bool, int?)
Initializes a new instance of the ImportanceSampler class.
public ImportanceSampler(int datasetSize, double smoothingFactor = 0.2, bool stabilize = true, int? seed = null)
Parameters
datasetSize (int): The total number of samples.
smoothingFactor (double): Smoothing factor to prevent extreme sampling (0.1-0.5 recommended).
stabilize (bool): Whether to clip extreme importance values.
seed (int?): Optional random seed for reproducibility.
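To make the role of a smoothing factor concrete, here is one common way such a factor is applied: blending the importance-based distribution with a uniform one. This is an illustrative assumption, not a statement of how ImportanceSampler computes its probabilities. The symbols are also assumptions: $s_i$ is the importance score of sample $i$, $N$ is the dataset size, and $\alpha$ is the smoothing factor.

```latex
p(i) = (1 - \alpha)\,\frac{s_i}{\sum_{j=1}^{N} s_j} + \frac{\alpha}{N}
```

Under this assumed form, with $N = 4$, normalized scores $(0.7, 0.1, 0.1, 0.1)$, and $\alpha = 0.2$, the probabilities become $(0.61, 0.13, 0.13, 0.13)$. Every sample keeps a probability of at least $\alpha / N$, so no sample is ever starved completely.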
Fields
NumOps
Numeric operations for type T.
protected static readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Properties
ImportanceScores
Gets the importance scores for all samples.
public IReadOnlyList<T> ImportanceScores { get; }
Property Value
- IReadOnlyList<T>
Length
Gets the total number of samples this sampler will produce per epoch.
public override int Length { get; }
Property Value
- int
Remarks
This may differ from the dataset size for oversampling or undersampling strategies.
Methods
GetCorrectionFactor(int)
Gets the sampling weight correction factor for a sample.
public T GetCorrectionFactor(int index)
Parameters
index (int): The sample index.
Returns
- T
The correction factor to apply to gradients for unbiased estimation.
Remarks
When using importance sampling, gradients should be corrected by 1/p(i) where p(i) is the probability of sampling index i. This ensures unbiased gradient estimates despite non-uniform sampling.
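The reason an inverse-probability correction removes the bias can be shown in one line. The notation is assumed for illustration: $g_i$ is the gradient contributed by sample $i$, $N$ is the dataset size, and $p(i)$ is the probability of drawing index $i$.

```latex
\mathbb{E}_{i \sim p}\!\left[\frac{g_i}{N\,p(i)}\right]
  = \sum_{i=1}^{N} p(i)\,\frac{g_i}{N\,p(i)}
  = \frac{1}{N}\sum_{i=1}^{N} g_i
```

The weighted gradient therefore has the same expectation as the uniform average over the dataset. The $1/p(i)$ factor described above matches this weight up to the constant $1/N$. Whether the value returned by GetCorrectionFactor already includes that normalization is not stated here, so it is worth checking before combining it with a learning rate tuned for uniform sampling.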
GetIndicesCore()
Core implementation for generating indices. Override this in derived classes.
protected override IEnumerable<int> GetIndicesCore()
Returns
- IEnumerable<int>
An enumerable of sample indices.
GetIndicesWithoutReplacement(int)
Gets importance-weighted sample indices without replacement.
public IEnumerable<int> GetIndicesWithoutReplacement(int count)
Parameters
count (int): Number of samples to draw.
Returns
- IEnumerable<int>
Sampled indices.
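For readers who want to see what importance-weighted sampling without replacement can look like, the following is a minimal standalone sketch. It uses the Efraimidis-Spirakis key method, which is a technique swapped in for illustration and not necessarily what GetIndicesWithoutReplacement does internally. The class WeightedSamplingSketch and its method are hypothetical names, not part of the documented API.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper for illustration only; not part of the documented API.
public static class WeightedSamplingSketch
{
    // Draws `count` distinct indices; larger weights are more likely to be picked.
    // Efraimidis-Spirakis: give each item the key u^(1/w) and keep the largest keys.
    public static List<int> SampleWithoutReplacement(
        IReadOnlyList<double> weights, int count, Random rng)
    {
        if (count < 0 || count > weights.Count)
            throw new ArgumentOutOfRangeException(nameof(count));

        var keyed = new List<(double Key, int Index)>(weights.Count);
        for (int i = 0; i < weights.Count; i++)
        {
            if (weights[i] <= 0.0)
                throw new ArgumentException("Weights must be positive.", nameof(weights));

            // NextDouble() is in [0, 1), so u is in (0, 1] and Math.Log(u) is finite.
            double u = 1.0 - rng.NextDouble();
            // log(u) / w orders items exactly like u^(1/w) but avoids underflow.
            keyed.Add((Math.Log(u) / weights[i], i));
        }

        return keyed
            .OrderByDescending(k => k.Key)
            .Take(count)
            .Select(k => k.Index)
            .ToList();
    }
}
```

Calling `WeightedSamplingSketch.SampleWithoutReplacement(new double[] { 1.0, 1.0, 8.0, 1.0 }, 2, new Random(42))` returns two distinct indices in the range 0 to 3. Over many draws index 2 appears most often because its weight is the largest. Passing a seeded Random mirrors the purpose of the seed constructor parameter: the same seed reproduces the same draw.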
SetImportances(IReadOnlyList<T>)
Sets all importance scores at once.
public void SetImportances(IReadOnlyList<T> importances)
Parameters
importances (IReadOnlyList<T>): The importance scores for all samples.
UpdateImportance(int, T)
Updates the importance score for a single sample.
public void UpdateImportance(int index, T importance)
Parameters
index (int): The sample index.
importance (T): The new importance score.
UpdateImportances(IReadOnlyList<int>, IReadOnlyList<T>)
Batch updates importance scores.
public void UpdateImportances(IReadOnlyList<int> indices, IReadOnlyList<T> importances)
Parameters
indices (IReadOnlyList<int>): The sample indices.
importances (IReadOnlyList<T>): The new importance scores, matched to indices by position.