Class MemoryBank<T>
Namespace: AiDotNet.SelfSupervisedLearning
Assembly: AiDotNet.dll
FIFO memory queue for storing embeddings in contrastive learning.
public class MemoryBank<T> : IMemoryBank<T>
Type Parameters
T: The numeric type used for computations (typically float or double).
Inheritance: object → MemoryBank<T>
Implements: IMemoryBank<T>
Remarks
For Beginners: A memory bank is a queue that stores embeddings from previous batches. This provides a large pool of negative samples for contrastive learning without requiring huge batch sizes.
How it works:
- New embeddings are added to the end of the queue (Enqueue)
- When the queue is full, oldest embeddings are removed (FIFO)
- All stored embeddings serve as negative samples for contrastive loss
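The FIFO behavior described above can be pictured as a fixed-size ring buffer: a write pointer advances on every enqueued row and wraps around, so once the buffer is full each new row overwrites the oldest one. The sketch below is a minimal illustration under that assumption. It stores plain float[] rows instead of Tensor<T>, and the class name RingBufferSketch is hypothetical; it is not AiDotNet's actual implementation.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical illustration of a FIFO memory bank as a ring buffer of float rows.
// Not AiDotNet's implementation; Tensor<T> is replaced by float[] for simplicity.
public sealed class RingBufferSketch
{
    private readonly float[][] _slots;
    private int _next;   // slot that the next enqueued row will overwrite
    private int _count;  // number of valid rows currently stored

    public RingBufferSketch(int capacity, int dim)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _slots = new float[capacity][];
        Dim = dim;
    }

    public int Dim { get; }
    public int Capacity => _slots.Length;
    public int CurrentSize => _count;
    public bool IsFull => _count == Capacity;

    // Append a batch of rows; once full, each new row replaces the oldest row.
    public void Enqueue(IEnumerable<float[]> batch)
    {
        foreach (var row in batch)
        {
            if (row.Length != Dim) throw new ArgumentException("Wrong embedding dimension.");
            _slots[_next] = (float[])row.Clone();
            _next = (_next + 1) % Capacity;
            if (_count < Capacity) _count++;
        }
    }

    // Return copies of the stored rows, ordered from oldest to newest.
    public float[][] GetAll()
    {
        var result = new float[_count][];
        int start = IsFull ? _next : 0; // when full, _next points at the oldest row
        for (int i = 0; i < _count; i++)
            result[i] = (float[])_slots[(start + i) % Capacity].Clone();
        return result;
    }

    // Drop all rows and reset the write pointer.
    public void Clear()
    {
        Array.Clear(_slots, 0, _slots.Length);
        _next = 0;
        _count = 0;
    }
}
```

For example, with capacity 3, enqueuing four rows 1, 2, 3, 4 leaves rows 2, 3, 4 in the buffer: row 1 was the oldest and was overwritten.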
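Example usage:
```csharp
using AiDotNet.SelfSupervisedLearning;

// Create a memory bank with 65536 entries of dimension 128
var memoryBank = new MemoryBank<float>(capacity: 65536, embeddingDim: 128);

// Each training step (queries/keys: this batch's encoder and momentum-encoder outputs):
var negatives = memoryBank.GetAll();  // stored embeddings serve as negatives
var loss = ComputeContrastiveLoss(queries, keys, negatives);  // your own loss function
memoryBank.Enqueue(keys);  // add this batch's keys after computing the loss
```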
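ComputeContrastiveLoss in the example is not part of this class. One common choice is the InfoNCE loss used by MoCo: each query is scored against its own key (the positive) and against every stored embedding (the negatives), and the loss is the cross-entropy of picking the positive. The sketch below assumes L2-normalized float[] vectors rather than Tensor<T>, and uses a temperature of 0.07 borrowed from MoCo as a default. The class InfoNceSketch is hypothetical.

```csharp
using System;
using System.Linq;

// Hypothetical InfoNCE loss over plain float arrays (a stand-in for ComputeContrastiveLoss).
// Assumes all vectors are already L2-normalized, so dot products are cosine similarities.
public static class InfoNceSketch
{
    static double Dot(float[] a, float[] b)
    {
        double s = 0;
        for (int i = 0; i < a.Length; i++) s += (double)a[i] * b[i];
        return s;
    }

    // queries[i] and keys[i] form a positive pair; every row of negatives
    // (for example, the memory bank contents) is a negative for every query.
    public static double Loss(float[][] queries, float[][] keys, float[][] negatives, double temperature = 0.07)
    {
        if (queries.Length == 0) throw new ArgumentException("Need at least one query.");
        double total = 0;
        for (int i = 0; i < queries.Length; i++)
        {
            // Logit 0 is the positive pair; logits 1..N are the negatives.
            var logits = new double[1 + negatives.Length];
            logits[0] = Dot(queries[i], keys[i]) / temperature;
            for (int j = 0; j < negatives.Length; j++)
                logits[j + 1] = Dot(queries[i], negatives[j]) / temperature;

            // Cross-entropy with the positive as the target, computed via a stable log-sum-exp.
            double max = logits.Max();
            double logSumExp = max + Math.Log(logits.Sum(l => Math.Exp(l - max)));
            total += logSumExp - logits[0];
        }
        return total / queries.Length;
    }
}
```

With no negatives the loss is 0; adding a negative identical to the key raises it to log 2, because the positive can no longer be distinguished. Negatives that point away from the query contribute almost nothing, which is why a large, diverse pool helps.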
Constructors
MemoryBank(int, int, int?)
Initializes a new instance of the MemoryBank class.
public MemoryBank(int capacity, int embeddingDim, int? seed = null)
Parameters
capacity (int): Maximum number of embeddings to store (e.g., 65536).
embeddingDim (int): Dimension of each embedding (e.g., 128).
seed (int?): Optional random seed used for sampling.
Properties
Capacity
Gets the maximum capacity of the memory bank (queue size).
public int Capacity { get; }
Property Value: int
Remarks
Typical values: 4096-65536. MoCo uses 65536 by default.
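Capacity largely determines memory use. As a rough estimate, assuming the bank keeps one T value per embedding element with no extra overhead (an assumption about the implementation), raw storage is capacity × embeddingDim × sizeof(T). For float, 65536 × 128 × 4 bytes = 33,554,432 bytes (32 MiB); double needs twice that. The helper below is hypothetical and only does this arithmetic.

```csharp
using System;

// Hypothetical back-of-the-envelope estimate of a memory bank's raw embedding storage.
// Ignores object headers, bookkeeping fields, and any temporary copies.
public static class MemoryBankFootprint
{
    public static long Bytes(int capacity, int embeddingDim, int bytesPerElement) =>
        (long)capacity * embeddingDim * bytesPerElement;
}
```

For example, a 4,096-entry bank of 128-dimensional float embeddings needs about 2 MiB under the same assumption.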
CurrentSize
Gets the current number of stored embeddings.
public int CurrentSize { get; }
Property Value: int
EmbeddingDimension
Gets the embedding dimension of stored vectors.
public int EmbeddingDimension { get; }
Property Value: int
IsFull
Gets whether the memory bank is full (has reached capacity).
public bool IsFull { get; }
Property Value: bool
Methods
Clear()
Clears all stored embeddings and resets the memory bank.
public void Clear()
Enqueue(Tensor<T>)
Adds new embeddings to the memory bank (FIFO queue).
public void Enqueue(Tensor<T> embeddings)
Parameters
embeddings (Tensor<T>): The embeddings to add (a batch of vectors).
Remarks
For Beginners: New embeddings are added to the end of the queue. When the queue is full, the oldest embeddings are removed (first-in, first-out).
GetAll()
Gets all stored embeddings for use as negative samples.
public Tensor<T> GetAll()
Returns
Tensor<T>: A tensor containing all stored embeddings, with shape [CurrentSize, EmbeddingDimension].
Remarks
For Beginners: These embeddings serve as negative samples in contrastive loss. The more negatives you have, the harder and more informative the contrastive task becomes.
GetAt(int)
Gets the embedding at a specific index.
public Tensor<T> GetAt(int index)
Parameters
index (int): The index of the embedding to retrieve.
Returns
Tensor<T>: The embedding tensor, with shape [1, EmbeddingDimension].
Sample(int)
Gets a random subset of stored embeddings.
public Tensor<T> Sample(int count)
Parameters
count (int): The number of embeddings to retrieve.
Returns
Tensor<T>: A tensor of randomly sampled embeddings, with shape [count, EmbeddingDimension].
Remarks
Useful when you want fewer negatives than the full memory bank.
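The documentation does not say whether Sample draws with or without replacement. As one possible approach, the sketch below picks count distinct rows uniformly at random with a partial Fisher-Yates shuffle over row indices, using a seeded System.Random so results are repeatable (in the spirit of the constructor's seed parameter). It works on plain float[] rows; SampleSketch is a hypothetical name, not AiDotNet code.

```csharp
using System;

// Hypothetical sampling without replacement via a partial Fisher-Yates shuffle.
// Not AiDotNet's implementation; the library may sample differently.
public static class SampleSketch
{
    public static float[][] Sample(float[][] stored, int count, Random rng)
    {
        if (count < 0 || count > stored.Length)
            throw new ArgumentOutOfRangeException(nameof(count));

        // Start with the identity permutation of row indices.
        var idx = new int[stored.Length];
        for (int i = 0; i < idx.Length; i++) idx[i] = i;

        var result = new float[count][];
        for (int i = 0; i < count; i++)
        {
            int j = rng.Next(i, idx.Length); // uniform position in [i, n)
            (idx[i], idx[j]) = (idx[j], idx[i]);
            result[i] = stored[idx[i]];      // rows 0..i are now a random distinct subset
        }
        return result;
    }
}
```

Only the first count positions are shuffled, so the cost is proportional to the bank size (for the index array) plus count, rather than requiring a full shuffle.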
SetAt(int, Tensor<T>)
Sets the embedding at a specific index.
public void SetAt(int index, Tensor<T> embedding)
Parameters
index (int): The index to set.
embedding (Tensor<T>): The embedding tensor, with shape [1, EmbeddingDimension] or [EmbeddingDimension].
UpdateWithMomentum(int[], Tensor<T>, double)
Updates the selected embeddings with an exponential moving average of their stored and new values (a soft update).
public void UpdateWithMomentum(int[] indices, Tensor<T> newEmbeddings, double momentum)
Parameters
indices (int[]): Indices of the embeddings to update.
newEmbeddings (Tensor<T>): New embedding values.
momentum (double): Weight given to the stored (old) embedding in the exponential moving average, between 0 and 1.
Remarks
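Some memory bank variants use soft updates instead of overwriting entries: updated = momentum * old + (1 - momentum) * new, where old is the stored embedding and new is the corresponding row of newEmbeddings. A larger momentum keeps more of the old value.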
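The soft update above can be sketched on plain float[] rows as follows. This is a minimal illustration, not AiDotNet's implementation; MomentumUpdateSketch is a hypothetical name, and the check that there is one new row per index is an assumption of the sketch.

```csharp
using System;

// Hypothetical soft (momentum) update of selected rows, in place:
// updated = momentum * old + (1 - momentum) * incoming.
public static class MomentumUpdateSketch
{
    public static void Update(float[][] bank, int[] indices, float[][] incoming, double momentum)
    {
        if (momentum < 0 || momentum > 1) throw new ArgumentOutOfRangeException(nameof(momentum));
        if (indices.Length != incoming.Length) throw new ArgumentException("Expected one new row per index.");

        for (int r = 0; r < indices.Length; r++)
        {
            float[] old = bank[indices[r]];
            for (int d = 0; d < old.Length; d++)
                old[d] = (float)(momentum * old[d] + (1 - momentum) * incoming[r][d]);
            // Some variants re-normalize the updated row to unit length here; omitted in this sketch.
        }
    }
}
```

With momentum 0.5, a stored row [1, 0] updated toward [0, 1] becomes [0.5, 0.5]; momentum 0 simply replaces the row, and momentum 1 leaves it unchanged.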