Class NTMAlgorithm<T, TInput, TOutput>
Namespace: AiDotNet.MetaLearning.Algorithms
Assembly: AiDotNet.dll
Implementation of Neural Turing Machine (NTM) for meta-learning.
public class NTMAlgorithm<T, TInput, TOutput> : MetaLearnerBase<T, TInput, TOutput>, IMetaLearner<T, TInput, TOutput>
Type Parameters
- T: The numeric type used for calculations (e.g., float, double).
- TInput: The input data type (e.g., Matrix<T>, Tensor<T>).
- TOutput: The output data type (e.g., Vector<T>, Tensor<T>).
- Inheritance: MetaLearnerBase<T, TInput, TOutput> → NTMAlgorithm<T, TInput, TOutput>
- Implements: IMetaLearner<T, TInput, TOutput>
Remarks
Neural Turing Machines augment neural networks with an external memory matrix and differentiable attention mechanisms for reading and writing. This enables algorithms to be learned and executed within the neural network itself.
For Beginners: NTM is like a neural computer with RAM:
How it works:
- Controller network processes inputs like a CPU
- Generates read/write keys for memory access
- Attention mechanism determines where to read/write
- External memory stores information persistently
- Differentiable operations allow end-to-end learning
Key difference from standard NN:
- Standard NN: Fixed computation graph
- NTM: Can learn to store and retrieve information dynamically
- Like giving a neural network a scratchpad to work with
Algorithm - Neural Turing Machine:

# Components
controller = LSTM() or MLP()    # Processes inputs and outputs
memory = MemoryMatrix(N x M)    # N locations, M dimensions each
read_heads = [ReadHead() x R]   # R parallel read heads
write_head = WriteHead()        # Single write head

# Forward pass
for each timestep t:
    # Controller receives input and previous reads
    controller_input = concatenate(x_t, read_contents[t-1])
    controller_output = controller(controller_input)

    # Generate read/write addressing
    read_keys = controller.generate_read_keys(controller_output)
    write_key = controller.generate_write_key(controller_output)
    write_erase = controller.generate_erase_vector(controller_output)
    write_add = controller.generate_add_vector(controller_output)

    # Read from memory using attention
    read_contents = []
    for each read_head in read_heads:
        weights = attention(read_head.key, memory)
        content = weighted_sum(weights, memory)
        read_contents.append(content)

    # Write to memory: erase, then add (outer products of the
    # write weights with the erase/add vectors)
    write_weights = attention(write_key, memory)
    memory = memory * (1 - outer(write_weights, write_erase))
    memory = memory + outer(write_weights, write_add)

    # Generate output
    output = controller.generate_output(controller_output, read_contents)
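The attention step in the pseudocode is typically content-based addressing: cosine similarity between the controller's key and each memory row, sharpened by a key-strength scalar and normalized with a softmax. A minimal NumPy sketch (array shapes and the toy memory are illustrative assumptions, not the library's internals):

```python
import numpy as np

def content_addressing(key, memory, beta=1.0):
    """Content-based addressing: softmax over sharpened cosine similarities.

    key:    (M,)   query vector emitted by the controller
    memory: (N, M) N locations of width M
    beta:   key strength; larger values sharpen the focus
    """
    # Cosine similarity between the key and every memory row.
    norms = np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-8
    sim = memory @ key / norms                       # (N,)
    # Softmax turns similarities into attention weights summing to 1.
    e = np.exp(beta * sim - np.max(beta * sim))
    return e / e.sum()                               # (N,)

def read(weights, memory):
    """Differentiable read: convex combination of memory rows."""
    return weights @ memory                          # (M,)

memory = np.eye(4, 3)   # toy memory: 4 locations, width 3
w = content_addressing(np.array([1.0, 0.0, 0.0]), memory, beta=5.0)
r = read(w, memory)     # attention mass concentrates on the matching row
```

Because every operation here is smooth in `key` and `memory`, gradients flow through both the addressing and the read, which is what makes end-to-end training possible.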
Key Insights:
Differentiable Memory: Both reading and writing use differentiable attention, allowing the entire system to be trained with backpropagation.
Algorithmic Learning: NTM can learn to implement algorithms like sorting, copying, and associative recall directly from examples.
Variable Computation: The computation graph can change based on what's stored in memory, enabling dynamic reasoning.
Persistent State: Information can be stored across timesteps, enabling long-term memory and reasoning.
Production Features:
- LSTM or MLP controllers
- Multiple read/write heads
- Content-based and location-based addressing
- Memory initialization strategies
- Memory usage monitoring
- Differentiable memory operations
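The erase-then-add write described in the algorithm can be sketched in NumPy as two outer-product updates (shapes and the toy values are illustrative assumptions):

```python
import numpy as np

def write(memory, weights, erase, add):
    """NTM write: erase then add, both differentiable.

    memory:  (N, M) external memory matrix
    weights: (N,)   attention over locations
    erase:   (M,)   per-dimension erase gates in [0, 1]
    add:     (M,)   new content to blend in
    """
    # Outer products broadcast each location's weight across its row.
    memory = memory * (1.0 - np.outer(weights, erase))
    memory = memory + np.outer(weights, add)
    return memory

M = np.zeros((4, 3))
w = np.array([1.0, 0.0, 0.0, 0.0])   # focus entirely on row 0
M = write(M, w, erase=np.ones(3), add=np.array([0.5, 0.2, 0.1]))
# Row 0 now holds the added vector; the other rows are untouched.
```

Soft attention weights interpolate between these extremes: a weight of 0.5 on a row half-erases and half-writes it, which keeps the whole update differentiable.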
Constructors
NTMAlgorithm(NTMOptions<T, TInput, TOutput>)
Initializes a new instance of the NTMAlgorithm class.
public NTMAlgorithm(NTMOptions<T, TInput, TOutput> options)
Parameters
- options (NTMOptions<T, TInput, TOutput>): The configuration options for NTM.
Remarks
For Beginners: This creates a Neural Turing Machine ready for meta-learning:
What NTM needs:
- controller: LSTM or MLP that controls the system
- memorySize: Size of external memory matrix
- memoryWidth: Dimension of each memory location
- numReadHeads: How many parallel read operations
- controllerType: LSTM for sequences, MLP for fixed-size inputs
What makes it special:
- Can learn algorithms (sorting, copying) from data
- Has external memory like RAM
- Memory operations are differentiable
- Can reason and plan using stored information
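These options fix the tensor shapes inside the machine: in a standard NTM the controller at each step sees the raw input concatenated with one read vector per read head, so its input width is the input size plus numReadHeads × memoryWidth. A quick sanity check in Python (the option names mirror the list above; the exact property names on NTMOptions may differ):

```python
def controller_input_size(input_size, num_read_heads, memory_width):
    """Width of the controller's per-step input: the raw observation
    concatenated with one read vector per read head."""
    return input_size + num_read_heads * memory_width

# 10-dim inputs, a memory matrix with width 20, and 2 read heads:
width = controller_input_size(input_size=10, num_read_heads=2, memory_width=20)
# width == 50: each step the controller sees 10 + 2 * 20 values
```

Note that memorySize (the number of locations N) affects the attention weights' length but not the controller's input width, since each read returns a single memoryWidth-dimensional vector.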
Exceptions
- ArgumentNullException
Thrown when options or required components are null.
- ArgumentException
Thrown when configuration validation fails.
Properties
AlgorithmType
Gets the type of meta-learning algorithm.
public override MetaLearningAlgorithmType AlgorithmType { get; }
Property Value
- MetaLearningAlgorithmType
Methods
Adapt(IMetaLearningTask<T, TInput, TOutput>)
Adapts the model to a new task using its support set.
public override IModel<TInput, TOutput, ModelMetadata<T>> Adapt(IMetaLearningTask<T, TInput, TOutput> task)
Parameters
- task (IMetaLearningTask<T, TInput, TOutput>): The task to adapt to.
Returns
- IModel<TInput, TOutput, ModelMetadata<T>>
A new model instance adapted to the task.
Remarks
For Beginners: This is where the "quick learning" happens. Given a new task with just a few examples (the support set), this method creates a new model that's specialized for that specific task.
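Conceptually, adaptation in a memory-augmented model amounts to encoding the support set into external memory so that later queries can retrieve it by attention. A toy sketch of that idea (not the actual Adapt implementation; the helper and its behavior are illustrative assumptions):

```python
import numpy as np

def adapt_and_predict(support_x, support_y, query_x):
    """Toy sketch: the support set is written into memory as
    (key, value) rows; a query then reads via soft attention.
    Mimics the spirit of Adapt, not the library's implementation."""
    memory_keys = np.array(support_x, dtype=float)   # one row per example
    memory_vals = np.array(support_y, dtype=float)
    # Content-based read: softmax over dot-product similarities.
    sim = memory_keys @ np.asarray(query_x, dtype=float)
    w = np.exp(sim - sim.max())
    w /= w.sum()
    return w @ memory_vals                           # soft label retrieval

pred = adapt_and_predict(
    support_x=[[1, 0], [0, 1]],   # two support examples
    support_y=[0.0, 1.0],         # their labels
    query_x=[0.9, 0.1],           # query close to the first example
)
# pred is pulled toward the label of the nearer support example (0.0)
```

The key property is that no gradient steps are needed at adaptation time: "learning" the task is just writing to memory, which is why adaptation is fast.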
MetaTrain(TaskBatch<T, TInput, TOutput>)
Performs one meta-training step on a batch of tasks.
public override T MetaTrain(TaskBatch<T, TInput, TOutput> taskBatch)
Parameters
- taskBatch (TaskBatch<T, TInput, TOutput>): The batch of tasks to train on.
Returns
- T
The meta-training loss for this batch.
Remarks
For Beginners: This method updates the model by training on multiple tasks at once. Each task teaches the model something about how to learn quickly. The returned loss value indicates how well the model is doing - lower is better.
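The overall shape of such a step can be sketched as: adapt to each task, measure the post-adaptation loss, and average over the batch. A minimal illustration (the callback and task representation are hypothetical; the real MetaTrain backpropagates the batch loss through the controller and the differentiable memory operations):

```python
def meta_train_step(adapt_and_eval, tasks):
    """Sketch of one meta-training step: average the per-task loss
    over a batch of tasks (names here are illustrative)."""
    losses = [adapt_and_eval(task) for task in tasks]
    batch_loss = sum(losses) / len(losses)
    # A real implementation would now backpropagate batch_loss and
    # update the controller/memory parameters before returning it.
    return batch_loss

# Toy usage: each "task" evaluates to a precomputed loss value.
loss = meta_train_step(lambda t: t["loss"], [{"loss": 0.4}, {"loss": 0.2}])
```

Averaging over several tasks per step is what pushes the model toward parameters that adapt well in general, rather than fitting any single task.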