Class BalancedEpisodicDataLoader<T, TInput, TOutput>
Provides balanced episodic task sampling that ensures equal class representation across multiple tasks.
public class BalancedEpisodicDataLoader<T, TInput, TOutput> : EpisodicDataLoaderBase<T, TInput, TOutput>, IEpisodicDataLoader<T, TInput, TOutput>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<MetaLearningTask<T, TInput, TOutput>>
Type Parameters
T - The numeric data type used for features and labels (e.g., float, double).
TInput - The input data type for a task's support and query sets.
TOutput - The output (label) data type for a task's support and query sets.
- Inheritance
  EpisodicDataLoaderBase<T, TInput, TOutput> → BalancedEpisodicDataLoader<T, TInput, TOutput>
- Implements
  IEpisodicDataLoader<T, TInput, TOutput>, IDataLoader<T>, IResettable, ICountable, IBatchIterable<MetaLearningTask<T, TInput, TOutput>>
Examples
// Load a dataset with an imbalanced class distribution
var features = new Matrix<double>(1000, 784);
var labels = new Vector<double>(1000); // 20 classes with varying frequencies

// Create a balanced loader - all classes will be sampled equally over time
var loader = new BalancedEpisodicDataLoader<double, Matrix<double>, Vector<double>>(
    datasetX: features,
    datasetY: labels,
    nWay: 5,
    kShot: 3,
    queryShots: 10,
    seed: 42
);

// Over 1000 episodes, each class will appear in roughly the same number of tasks
for (int episode = 0; episode < 1000; episode++)
{
    var task = loader.GetNextTask();
    // Train on a balanced distribution of classes
}
Remarks
The BalancedEpisodicDataLoader extends the standard episodic loader by tracking class usage across multiple tasks and preferentially sampling under-represented classes. This ensures that over many episodes, all classes appear roughly the same number of times, preventing bias toward frequently-sampled classes.
For Beginners: Standard random sampling might pick some classes more often than others by chance. This can cause problems in meta-learning:
- The model might learn some classes better than others
- Training could be biased toward frequently-sampled classes
- Evaluation metrics might be skewed
The balanced loader solves this by:
- Tracking how many times each class has been selected
- Preferring classes that haven't been used as much
- Ensuring fair representation across all classes over time
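The tracking-and-preference loop above can be sketched as a small weighted sampler. This is an illustrative sketch only: the `BalancedSampler` class name and the inverse-usage weight formula `1 / (1 + count)` are assumptions, not the library's actual internals.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative sketch: the weight formula 1 / (1 + usageCount) and the
// BalancedSampler name are assumptions, not the library's actual internals.
class BalancedSampler
{
    private readonly Dictionary<int, int> _usageCounts;
    private readonly Random _rng;

    public BalancedSampler(IEnumerable<int> classLabels, int seed)
    {
        // Every class starts with a usage count of zero.
        _usageCounts = classLabels.Distinct().ToDictionary(c => c, c => 0);
        _rng = new Random(seed);
    }

    // Draws nWay distinct classes; classes used less often get higher weight.
    public List<int> SelectClasses(int nWay)
    {
        var selected = new List<int>();
        var candidates = _usageCounts.Keys.ToList();
        for (int i = 0; i < nWay; i++)
        {
            double total = candidates.Sum(c => 1.0 / (1 + _usageCounts[c]));
            double roll = _rng.NextDouble() * total;
            int chosen = candidates[candidates.Count - 1]; // fallback for rounding
            foreach (int c in candidates)
            {
                roll -= 1.0 / (1 + _usageCounts[c]);
                if (roll <= 0) { chosen = c; break; }
            }
            selected.Add(chosen);
            candidates.Remove(chosen); // no duplicate classes within one task
        }
        foreach (int c in selected) _usageCounts[c]++; // record usage
        return selected;
    }
}
```

Each draw removes the chosen class from the candidate pool, so a single task never repeats a class; the usage counts persist across tasks, which is what pulls long-run class frequencies toward uniform.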
When to use this:
- Long meta-training runs where balanced class exposure matters
- When your dataset has many classes and you want uniform coverage
- When evaluating meta-learning algorithms fairly across all classes
Trade-off: Less random than uniform sampling, but more balanced. Good for training, but you might want standard EpisodicDataLoader for final evaluation to match real-world randomness.
Thread Safety: This class is not thread-safe due to internal state tracking. Create separate instances for concurrent task generation.
Performance: Slightly slower than standard EpisodicDataLoader due to usage tracking and weighted sampling, but still O(nWay × (kShot + queryShots)) for task creation.
Constructors
BalancedEpisodicDataLoader(Matrix<T>, Vector<T>, int, int, int, int?)
Initializes a new instance of the BalancedEpisodicDataLoader for balanced N-way K-shot task sampling.
public BalancedEpisodicDataLoader(Matrix<T> datasetX, Vector<T> datasetY, int nWay = 5, int kShot = 5, int queryShots = 15, int? seed = null)
Parameters
datasetX (Matrix<T>) - The feature matrix where each row is an example. Shape: [num_examples, num_features].
datasetY (Vector<T>) - The label vector containing class labels for each example. Length: num_examples.
nWay (int) - The number of unique classes per task. Must be at least 2.
kShot (int) - The number of support examples per class. Must be at least 1.
queryShots (int) - The number of query examples per class. Must be at least 1.
seed (int?) - Optional random seed for reproducible task sampling. If null, uses a time-based seed.
Remarks
For Beginners: This constructor sets up balanced sampling by initializing usage tracking for all classes. Each class starts with a count of zero, and as tasks are generated, the loader keeps track of which classes have been used and prefers less-used ones.
The balancing happens automatically - you don't need to do anything special. Just call GetNextTask() repeatedly and the loader will ensure balanced class distribution over time.
Exceptions
- ArgumentNullException
Thrown when datasetX or datasetY is null.
- ArgumentException
Thrown when dimensions are invalid or dataset is too small.
Methods
GetNextTaskCore()
Core implementation of balanced N-way K-shot task sampling with weighted class selection.
protected override MetaLearningTask<T, TInput, TOutput> GetNextTaskCore()
Returns
- MetaLearningTask<T, TInput, TOutput>
A MetaLearningTask with balanced class sampling over time.
Remarks
This method extends the standard sampling algorithm with balanced selection:
1. Calculates selection weights: classes with lower usage get higher weights
2. Performs weighted random selection of N classes (favoring under-used classes)
3. For each selected class, randomly samples (K + queryShots) examples
4. Shuffles and splits into support and query sets
5. Updates usage counts for selected classes
6. Constructs and returns the MetaLearningTask
For Beginners: The balancing works like this:
Imagine you have 10 classes. After 5 tasks:
- Classes 0, 2, 5 have been used 3 times each
- Classes 1, 4, 7 have been used 2 times each
- Classes 3, 6, 8, 9 have been used 1 time each
For the next task, the loader will heavily favor classes 3, 6, 8, 9 (used least), moderately favor classes 1, 4, 7, and deprioritize classes 0, 2, 5 (used most). No class is ever excluded outright; the selection is weighted, not hard-filtered.
Over many tasks, this ensures all classes get approximately equal representation, leading to more balanced meta-learning training.
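The arithmetic behind that preference can be made concrete. Assuming a hypothetical inverse-usage weight of 1 / (1 + count) (the real formula may differ, but any decreasing function of usage count produces the same ordering), the selection probabilities for the ten classes above work out as:

```csharp
using System;
using System.Linq;

class WeightDemo
{
    static void Main()
    {
        // Usage counts after 5 tasks, matching the example above:
        // classes 0, 2, 5 used 3x; classes 1, 4, 7 used 2x; classes 3, 6, 8, 9 used 1x.
        int[] usage = { 3, 2, 3, 1, 2, 3, 1, 2, 1, 1 };

        // Hypothetical weight formula: 1 / (1 + usageCount).
        double[] weights = usage.Select(u => 1.0 / (1 + u)).ToArray();
        double total = weights.Sum(); // 4 * 0.5 + 3 * (1/3) + 3 * 0.25 = 3.75

        for (int c = 0; c < usage.Length; c++)
            Console.WriteLine($"class {c}: used {usage[c]}x, P(select) = {weights[c] / total:F3}");
        // Least-used classes (3, 6, 8, 9) end up twice as likely as most-used (0, 2, 5).
    }
}
```

Under this formula a class used once gets weight 0.5 while a class used three times gets 0.25, so the least-used classes are exactly twice as likely to be picked on the next draw.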