Class NeuralTuringMachine<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents a Neural Turing Machine, which is a neural network architecture that combines a neural network with external memory.
public class NeuralTuringMachine<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The numeric type used for calculations, typically float or double.
Inheritance: NeuralNetworkBase<T> → NeuralTuringMachine<T>
Remarks
A Neural Turing Machine (NTM) extends traditional neural networks by adding an external memory component that the network can read from and write to. This allows the network to store and retrieve information over long sequences, making it particularly effective for tasks requiring complex memory operations.
For Beginners: A Neural Turing Machine is like a neural network with a "notebook" that it can write to and read from.
Think of it like a student solving a math problem:
- The student (neural network) can process information directly
- But for complex problems, the student needs to write down intermediate steps in a notebook (external memory)
- The student can later refer back to these notes when needed
This memory capability helps the network:
- Remember information over long periods
- Store and retrieve specific pieces of data
- Learn more complex patterns that require step-by-step reasoning
For example, a standard neural network might struggle to add two long numbers, but an NTM can learn to write down partial results and carry digits, similar to how humans solve addition problems.
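The read and write mechanics described above can be sketched in a few lines of NumPy. This is a conceptual illustration of content-based addressing and blended reads/writes, not the AiDotNet API; all names below are hypothetical:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def content_addressing(memory, key, beta):
    """Weight each memory row by cosine similarity to the key, sharpened by beta."""
    sims = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-8)
    return softmax(beta * sims)

def read(memory, weights):
    """Read a blended vector: a weighted sum of memory rows."""
    return weights @ memory

def write(memory, weights, erase, add):
    """Erase, then add, each scaled per row by the addressing weights."""
    memory = memory * (1 - np.outer(weights, erase))
    return memory + np.outer(weights, add)

memory = np.zeros((4, 3))            # memorySize=4 rows, memoryVectorSize=3 columns
add_vec = np.array([1.0, 2.0, 3.0])
w = np.array([1.0, 0.0, 0.0, 0.0])   # addressing focused entirely on row 0
memory = write(memory, w, erase=np.zeros(3), add=add_vec)

key = np.array([1.0, 2.0, 3.0])      # later, look up the row most similar to this key
w_read = content_addressing(memory, key, beta=10.0)
print(read(memory, w_read))          # approximately recovers add_vec
```

Because addressing is a differentiable weighted blend over rows rather than a hard index, the whole read/write pipeline can be trained end to end with gradient descent.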
Constructors
NeuralTuringMachine(NeuralNetworkArchitecture<T>, int, int, int, ILossFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?)
Initializes a new instance of the NeuralTuringMachine<T> class.
public NeuralTuringMachine(NeuralNetworkArchitecture<T> architecture, int memorySize, int memoryVectorSize, int controllerSize, ILossFunction<T>? lossFunction = null, IActivationFunction<T>? contentAddressingActivation = null, IActivationFunction<T>? gateActivation = null, IActivationFunction<T>? outputActivation = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture to use for the NTM.
memorySize (int): The number of memory locations (rows in the memory matrix).
memoryVectorSize (int): The size of each memory vector (columns in the memory matrix).
controllerSize (int): The size of the controller network that manages memory operations.
lossFunction (ILossFunction<T>?): The loss function to use for training.
contentAddressingActivation (IActivationFunction<T>?): The activation function to apply to content-based addressing. If null, softmax will be used.
gateActivation (IActivationFunction<T>?): The activation function to apply to interpolation gates. If null, sigmoid will be used.
outputActivation (IActivationFunction<T>?): The activation function to apply to the final output. If null, a default based on task type will be used.
NeuralTuringMachine(NeuralNetworkArchitecture<T>, int, int, int, ILossFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?)
Initializes a new instance of the NeuralTuringMachine<T> class.
public NeuralTuringMachine(NeuralNetworkArchitecture<T> architecture, int memorySize, int memoryVectorSize, int controllerSize, ILossFunction<T>? lossFunction = null, IVectorActivationFunction<T>? contentAddressingActivation = null, IVectorActivationFunction<T>? gateActivation = null, IVectorActivationFunction<T>? outputActivation = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture to use for the NTM.
memorySize (int): The number of memory locations (rows in the memory matrix).
memoryVectorSize (int): The size of each memory vector (columns in the memory matrix).
controllerSize (int): The size of the controller network that manages memory operations.
lossFunction (ILossFunction<T>?): The loss function to use for training.
contentAddressingActivation (IVectorActivationFunction<T>?): The activation function to apply to content-based addressing. If null, softmax will be used.
gateActivation (IVectorActivationFunction<T>?): The activation function to apply to interpolation gates. If null, sigmoid will be used.
outputActivation (IVectorActivationFunction<T>?): The activation function to apply to the final output. If null, a default based on task type will be used.
Properties
AuxiliaryLossWeight
Gets or sets the weight for the memory usage auxiliary loss.
public T AuxiliaryLossWeight { get; set; }
Property Value: T
Remarks
This weight controls how much memory usage regularization contributes to the total loss. Typical values range from 0.001 to 0.01.
For Beginners: This controls how much we encourage focused memory access.
Common values:
- 0.005 (default): Balanced memory regularization
- 0.001-0.003: Light regularization
- 0.008-0.01: Strong regularization
Higher values encourage sharper, more focused memory usage.
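How this weight enters training can be sketched as follows (a hypothetical helper for illustration, not the library's internal code): the auxiliary term is simply scaled by the weight and added to the main task loss.

```python
def total_loss(task_loss, aux_loss, aux_weight=0.005, use_aux=True):
    """Combine the main task loss with the weighted memory-usage penalty.

    aux_weight=0.005 mirrors the documented default; use_aux mirrors
    the UseAuxiliaryLoss toggle.
    """
    if not use_aux:
        return task_loss
    return task_loss + aux_weight * aux_loss

print(total_loss(1.0, 2.0))                 # task loss plus a small penalty
print(total_loss(1.0, 2.0, use_aux=False))  # regularization disabled
```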
ContentAddressingActivation
The activation function to apply to content-based addressing similarity scores.
public IActivationFunction<T>? ContentAddressingActivation { get; }
Property Value: IActivationFunction<T>?
ContentAddressingVectorActivation
The activation function to apply to content-based addressing similarity scores.
public IVectorActivationFunction<T>? ContentAddressingVectorActivation { get; }
Property Value: IVectorActivationFunction<T>?
GateActivation
The activation function to apply to interpolation gates.
public IActivationFunction<T>? GateActivation { get; }
Property Value: IActivationFunction<T>?
GateVectorActivation
The activation function to apply to interpolation gates.
public IVectorActivationFunction<T>? GateVectorActivation { get; }
Property Value: IVectorActivationFunction<T>?
OutputActivation
The activation function to apply to the final output.
public IActivationFunction<T>? OutputActivation { get; }
Property Value: IActivationFunction<T>?
OutputVectorActivation
The activation function to apply to the final output.
public IVectorActivationFunction<T>? OutputVectorActivation { get; }
Property Value: IVectorActivationFunction<T>?
UseAuxiliaryLoss
Gets or sets whether auxiliary loss (memory usage regularization) should be used during training.
public bool UseAuxiliaryLoss { get; set; }
Property Value: bool
Remarks
Memory usage regularization prevents memory addressing from becoming too diffuse or collapsing. This encourages the NTM to learn focused, interpretable memory access patterns.
For Beginners: This helps the NTM use its memory notebook effectively.
Memory usage regularization ensures:
- Read/write operations focus on relevant memory locations
- Memory access doesn't spread too thin
- Memory operations are interpretable and efficient
This is like encouraging a student to:
- Write clearly in specific sections of the notebook
- Not scribble all over every page
- Use the notebook in an organized, focused way
Methods
ComputeAuxiliaryLoss()
Computes the auxiliary loss for memory usage regularization.
public T ComputeAuxiliaryLoss()
Returns
- T
The computed memory usage auxiliary loss.
Remarks
This method computes entropy-based regularization for memory read/write addressing. It encourages focused, sharp memory access patterns while preventing diffuse addressing. Formula: L = Σ H(w), where H(w) = -Σᵢ wᵢ log wᵢ is the Shannon entropy of an addressing weight vector; minimizing this penalty sharpens the addressing distributions.
For Beginners: This calculates how focused the NTM's memory usage is.
Memory usage regularization works by:
- Measuring entropy of read/write addressing weights
- Lower entropy means more focused, organized memory usage
- Higher entropy means scattered, disorganized access
- We minimize this entropy penalty to encourage focused access
This helps because:
- Focused memory access is more interpretable
- Sharp addressing improves efficiency
- Prevents wasting computation on irrelevant locations
- Encourages the NTM to use memory like an organized notebook
The auxiliary loss is added to the main task loss during training.
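The computation described above can be sketched as an entropy penalty over the addressing weights (a conceptual NumPy illustration, not the library's exact implementation): focused, near-one-hot weights give a loss near zero, while uniform weights give the maximum loss.

```python
import numpy as np

def addressing_entropy(weights, eps=1e-12):
    """Shannon entropy of each addressing distribution (rows sum to 1)."""
    w = np.clip(weights, eps, 1.0)  # avoid log(0)
    return -np.sum(w * np.log(w), axis=-1)

def memory_usage_loss(read_weights, write_weights):
    """Sum of entropies over all read/write heads; lower means sharper addressing."""
    return addressing_entropy(read_weights).sum() + addressing_entropy(write_weights).sum()

focused = np.array([[1.0, 0.0, 0.0, 0.0]])      # one-hot: entropy near 0
diffuse = np.array([[0.25, 0.25, 0.25, 0.25]])  # uniform: entropy = ln(4)

print(memory_usage_loss(focused, focused))  # near 0: well-organized access
print(memory_usage_loss(diffuse, diffuse))  # 2*ln(4): scattered access
```

During training this scalar would be scaled by the auxiliary loss weight and added to the task loss, so gradient descent pushes the addressing weights toward sharper distributions.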
CreateNewInstance()
Creates a new instance of the Neural Turing Machine model.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the Neural Turing Machine model with the same configuration.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes NTM-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read from.
GetAuxiliaryLossDiagnostics()
Gets diagnostic information about the memory usage auxiliary loss.
public Dictionary<string, string> GetAuxiliaryLossDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic information about memory usage regularization.
Remarks
This method returns detailed diagnostics about memory usage regularization, including addressing entropy, memory configuration, and regularization parameters. This information is useful for monitoring memory access patterns and debugging.
For Beginners: This provides information about how the NTM uses its memory.
The diagnostics include:
- Total memory usage loss (how focused memory access is)
- Weight applied to the regularization
- Memory size (number of memory locations)
- Memory vector size (size of each location)
- Whether memory usage regularization is enabled
This helps you:
- Monitor if memory addressing is focused or scattered
- Debug issues with memory access patterns
- Understand the impact of regularization on memory efficiency
You can use this information to adjust regularization weights for better memory utilization.
GetDiagnostics()
Gets diagnostic information about this component's state and behavior. Overrides GetDiagnostics() to include auxiliary loss diagnostics.
public Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic metrics including both base layer diagnostics and auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().
GetModelMetadata()
Gets metadata about the Neural Turing Machine model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the NTM.
InitializeLayers()
Initializes the neural network layers based on the provided architecture.
protected override void InitializeLayers()
Predict(Tensor<T>)
Performs a forward pass through the Neural Turing Machine.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process.
Returns
- Tensor<T>
The output tensor after processing.
ResetState()
Resets the internal state of the neural network.
public override void ResetState()
Remarks
For Beginners: This clears the memory matrix and attention weights, essentially making the network "forget" the information it accumulated while processing the current sequence (its trained parameters are unaffected). It's useful when starting to process a new sequence that should not be influenced by previous sequences.
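Conceptually, resetting state wipes the memory matrix and re-initializes the addressing weights while leaving trained parameters alone. A minimal sketch (hypothetical names, not the AiDotNet implementation):

```python
import numpy as np

class NtmState:
    """Minimal stand-in for the NTM's per-sequence state."""

    def __init__(self, memory_size=4, vector_size=3):
        self.shape = (memory_size, vector_size)
        self.reset_state()

    def reset_state(self):
        # Wipe memory and reset addressing to a uniform distribution,
        # so a new sequence starts with no trace of the previous one.
        self.memory = np.zeros(self.shape)
        self.read_weights = np.full(self.shape[0], 1.0 / self.shape[0])

s = NtmState()
s.memory[0] = [1.0, 2.0, 3.0]  # processing a sequence writes to memory
s.reset_state()                # start fresh before the next sequence
```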
SerializeNetworkSpecificData(BinaryWriter)
Serializes NTM-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write to.
SetTrainingMode(bool)
Sets the layer to training or evaluation mode.
public override void SetTrainingMode(bool isTraining)
Parameters
isTraining (bool): True to set the layer to training mode, false for evaluation mode.
Train(Tensor<T>, Tensor<T>)
Trains the Neural Turing Machine on a single batch of input-output pairs.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor for training.
expectedOutput (Tensor<T>): The expected output tensor.
UpdateParameters(Vector<T>)
Updates the parameters of the neural network layers.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parametersVector<T>The vector of parameter updates to apply.