Class LSTMNTMController<T, TInput, TOutput>
Namespace: AiDotNet.MetaLearning.Algorithms
Assembly: AiDotNet.dll
LSTM-based Neural Turing Machine (NTM) controller implementation with learnable parameters.
public class LSTMNTMController<T, TInput, TOutput> : INTMController<T>
Type Parameters
T: The numeric type.
TInput: The input data type.
TOutput: The output data type.
Inheritance: object → LSTMNTMController<T, TInput, TOutput>
Implements: INTMController<T>
Remarks
This controller uses an LSTM cell to process inputs and generate the addressing parameters for the NTM memory. Because the LSTM carries its hidden state across timesteps, the controller can model sequential and temporal dependencies.
Architecture:
Input (inputSize + numReadHeads * memoryWidth)
↓
LSTM Cell (hiddenSize)
↓
Linear projections → ReadKeys, WriteKey, Erase, Add, Output
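The input width in the diagram comes from concatenating the external input with one read vector per read head. The sketch below shows that concatenation under stated assumptions: it uses plain double[] arrays rather than AiDotNet's Tensor<T>, and the helper BuildControllerInput is hypothetical, not part of the library.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not AiDotNet API): concatenates the external input with the
// previous read vectors, giving inputSize + numReadHeads * memoryWidth values.
static double[] BuildControllerInput(double[] input, List<double[]> readContents, int memoryWidth)
{
    var result = new double[input.Length + readContents.Count * memoryWidth];
    Array.Copy(input, result, input.Length);
    int offset = input.Length;
    foreach (var read in readContents)
    {
        if (read.Length != memoryWidth)
            throw new ArgumentException("Each read vector must have memoryWidth elements.");
        Array.Copy(read, 0, result, offset, memoryWidth);
        offset += memoryWidth;
    }
    return result;
}

// Example: inputSize = 3, numReadHeads = 2, memoryWidth = 4, so 3 + 2 * 4 = 11 values.
var x = BuildControllerInput(
    new double[] { 1, 2, 3 },
    new List<double[]> { new double[4], new double[] { 1, 1, 1, 1 } },
    memoryWidth: 4);
Console.WriteLine(x.Length); // 11
```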
Constructors
LSTMNTMController(NTMOptions<T, TInput, TOutput>)
Initializes a new instance of LSTMNTMController with learnable weights.
public LSTMNTMController(NTMOptions<T, TInput, TOutput> options)
Parameters
options (NTMOptions<T, TInput, TOutput>): The NTM options.
Methods
Forward(Tensor<T>, List<Tensor<T>>)
Performs a forward pass through the controller for a single timestep.
public Tensor<T> Forward(Tensor<T> input, List<Tensor<T>> readContents)
Parameters
input (Tensor<T>): The input tensor.
readContents (List<Tensor<T>>): The read vectors produced by the read heads at the previous timestep.
Returns
Tensor<T>: The controller output.
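As a rough illustration of what one Forward step involves, the sketch below applies the standard LSTM cell equations to an already-concatenated input vector. It is not AiDotNet's implementation: the stacked weight layout, the gate order (input, forget, candidate, output), the helper LstmStep, and the use of double[] instead of Tensor<T> are all assumptions. Starting from a zeroed hidden and cell state is one plausible reading of what Reset() restores.

```csharp
using System;

static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Hypothetical LSTM step (not AiDotNet API). w has shape (4 * hidden) x (x.Length + hidden),
// b has 4 * hidden entries; rows are stacked in the assumed order i, f, g, o.
static (double[] h, double[] c) LstmStep(
    double[] x, double[] hPrev, double[] cPrev, double[,] w, double[] b)
{
    int hidden = hPrev.Length;
    int cols = x.Length + hidden;
    var z = new double[4 * hidden];
    for (int r = 0; r < 4 * hidden; r++)
    {
        double sum = b[r];
        for (int k = 0; k < cols; k++)
        {
            double v = k < x.Length ? x[k] : hPrev[k - x.Length]; // [x; hPrev]
            sum += w[r, k] * v;
        }
        z[r] = sum;
    }
    var h = new double[hidden];
    var c = new double[hidden];
    for (int j = 0; j < hidden; j++)
    {
        double i = Sigmoid(z[j]);                // input gate
        double f = Sigmoid(z[hidden + j]);       // forget gate
        double g = Math.Tanh(z[2 * hidden + j]); // candidate cell value
        double o = Sigmoid(z[3 * hidden + j]);   // output gate
        c[j] = f * cPrev[j] + i * g;
        h[j] = o * Math.Tanh(c[j]);
    }
    return (h, c);
}

// Zeroed state, as a reset would leave it (assumption).
int hiddenSize = 2, inputDim = 3;
var hState = new double[hiddenSize];
var cState = new double[hiddenSize];
var weights = new double[4 * hiddenSize, inputDim + hiddenSize]; // all zeros for the demo
var bias = new double[4 * hiddenSize];
for (int j = 0; j < hiddenSize; j++) bias[2 * hiddenSize + j] = 1.0; // candidate bias only

(hState, cState) = LstmStep(new double[] { 1, 2, 3 }, hState, cState, weights, bias);
Console.WriteLine($"h[0] = {hState[0]:F4}");
```

The returned hidden and cell states are fed back in on the next call, which is how the controller carries information across timesteps.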
GenerateAddVector(Tensor<T>)
Generates the add vector for the memory write operation.
public Tensor<T> GenerateAddVector(Tensor<T> output)
Parameters
output (Tensor<T>): The controller output.
Returns
Tensor<T>: The add vector.
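The add vector is a projection of the controller output to memoryWidth values. A minimal sketch, assuming a single linear layer followed by tanh (a common choice in NTM implementations; this page does not say which activation AiDotNet uses), with plain arrays and a hypothetical helper ProjectAddVector:

```csharp
using System;

// Hypothetical projection (not AiDotNet API): add = tanh(W * h + b),
// where W has shape memoryWidth x hiddenSize. The tanh activation is an assumption.
static double[] ProjectAddVector(double[] h, double[,] w, double[] b)
{
    int rows = w.GetLength(0), cols = w.GetLength(1);
    if (cols != h.Length || b.Length != rows)
        throw new ArgumentException("Shape mismatch.");
    var add = new double[rows];
    for (int r = 0; r < rows; r++)
    {
        double sum = b[r];
        for (int k = 0; k < cols; k++) sum += w[r, k] * h[k];
        add[r] = Math.Tanh(sum); // keeps entries in (-1, 1)
    }
    return add;
}

var hidden = new double[] { 0.5, -1.0 };
var wAdd = new double[,] { { 1.0, 0.0 }, { 0.0, 1.0 }, { 1.0, 1.0 } }; // memoryWidth = 3
var addVector = ProjectAddVector(hidden, wAdd, new double[3]);
Console.WriteLine(string.Join(", ", addVector));
```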
GenerateEraseVector(Tensor<T>)
Generates the erase vector for the memory write operation.
public Tensor<T> GenerateEraseVector(Tensor<T> output)
Parameters
output (Tensor<T>): The controller output.
Returns
Tensor<T>: The erase vector.
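In the original NTM formulation (Graves et al., 2014), erase vector entries lie in (0, 1), typically via a sigmoid, and a write with weighting w, erase vector e, and add vector a updates each memory row as M[i] ← M[i] ⊙ (1 − w[i]·e) + w[i]·a. The sketch below implements that rule with plain arrays; the helpers EraseFromLogits and WriteMemory are hypothetical, and whether AiDotNet applies the write exactly this way is an assumption.

```csharp
using System;

static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Hypothetical (not AiDotNet API): squash raw erase logits into (0, 1).
static double[] EraseFromLogits(double[] logits)
{
    var e = new double[logits.Length];
    for (int k = 0; k < logits.Length; k++) e[k] = Sigmoid(logits[k]);
    return e;
}

// Standard NTM write, per row i and column k:
// M[i, k] <- M[i, k] * (1 - w[i] * e[k]) + w[i] * a[k]
static void WriteMemory(double[,] memory, double[] writeWeights, double[] erase, double[] add)
{
    int n = memory.GetLength(0), m = memory.GetLength(1);
    for (int i = 0; i < n; i++)
        for (int k = 0; k < m; k++)
            memory[i, k] = memory[i, k] * (1.0 - writeWeights[i] * erase[k]) + writeWeights[i] * add[k];
}

var mem = new double[,] { { 1.0, 1.0 }, { 2.0, 2.0 } };         // N = 2 rows, memoryWidth = 2
var eraseVec = EraseFromLogits(new double[] { 100.0, -100.0 }); // approximately [1, 0]
WriteMemory(mem, new double[] { 1.0, 0.0 }, eraseVec, new double[] { 5.0, 5.0 });
Console.WriteLine($"{mem[0, 0]:F3} {mem[0, 1]:F3} {mem[1, 0]:F3} {mem[1, 1]:F3}");
```

Row 0 (weight 1) has its first column erased and then receives the add vector, while row 1 (weight 0) is left untouched.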
GenerateOutput(Tensor<T>, List<Tensor<T>>)
Generates the final output from the controller output and the current read contents.
public Tensor<T> GenerateOutput(Tensor<T> output, List<Tensor<T>> readContents)
Parameters
output (Tensor<T>): The controller output.
readContents (List<Tensor<T>>): The current read contents.
Returns
Tensor<T>: The final output.
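One straightforward way to combine the controller output with the current read contents is a single linear layer over their concatenation. The sketch assumes exactly that (the real output layer may add an activation or differ in shape) and uses a hypothetical ProjectOutput helper over plain arrays:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical output head (not AiDotNet API): y = W * [h; r_1; ...; r_R] + b.
static double[] ProjectOutput(double[] h, List<double[]> reads, double[,] w, double[] b)
{
    var features = new List<double>(h);
    foreach (var r in reads) features.AddRange(r);
    int rows = w.GetLength(0);
    if (w.GetLength(1) != features.Count || b.Length != rows)
        throw new ArgumentException("Shape mismatch.");
    var y = new double[rows];
    for (int o = 0; o < rows; o++)
    {
        double sum = b[o];
        for (int k = 0; k < features.Count; k++) sum += w[o, k] * features[k];
        y[o] = sum;
    }
    return y;
}

// hiddenSize = 2, one read head with memoryWidth = 2, outputSize = 1.
var output = ProjectOutput(
    new double[] { 1.0, 2.0 },
    new List<double[]> { new double[] { 3.0, 4.0 } },
    new double[,] { { 1.0, 1.0, 1.0, 1.0 } },
    new double[] { 0.5 });
Console.WriteLine(output[0]);
```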
GenerateReadKeys(Tensor<T>)
Generates read keys for all read heads.
public List<Tensor<T>> GenerateReadKeys(Tensor<T> output)
Parameters
output (Tensor<T>): The controller output.
Returns
List<Tensor<T>>: The read keys, one per read head.
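A common way to produce per-head keys is to project the controller output to numReadHeads * memoryWidth values and slice the result. The sketch below shows only the slicing step, assuming that layout; SplitReadKeys is a hypothetical helper working on plain arrays rather than Tensor<T>:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical (not AiDotNet API): split a projection with numReadHeads * memoryWidth
// entries into one key of length memoryWidth per read head.
static List<double[]> SplitReadKeys(double[] projected, int numReadHeads, int memoryWidth)
{
    if (projected.Length != numReadHeads * memoryWidth)
        throw new ArgumentException("Expected numReadHeads * memoryWidth values.");
    var keys = new List<double[]>(numReadHeads);
    for (int head = 0; head < numReadHeads; head++)
    {
        var key = new double[memoryWidth];
        Array.Copy(projected, head * memoryWidth, key, 0, memoryWidth);
        keys.Add(key);
    }
    return keys;
}

var readKeys = SplitReadKeys(new double[] { 1, 2, 3, 4, 5, 6 }, numReadHeads: 2, memoryWidth: 3);
Console.WriteLine(string.Join(" | ", readKeys.ConvertAll(k => string.Join(",", k)))); // 1,2,3 | 4,5,6
```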
GenerateWriteKey(Tensor<T>)
Generates the write key.
public Tensor<T> GenerateWriteKey(Tensor<T> output)
Parameters
output (Tensor<T>): The controller output.
Returns
Tensor<T>: The write key.
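In the standard NTM, a key such as the write key drives content-based addressing: each memory row i gets weight softmax_i(β · cos(k, M[i])), where β is a key strength. The sketch below assumes this scheme and a fixed β; ContentWeights and Cosine are hypothetical helpers, and how AiDotNet turns the write key into a weighting is not documented on this page.

```csharp
using System;

// Hypothetical helper (not AiDotNet API): cosine similarity with a small epsilon
// to avoid division by zero.
static double Cosine(double[] a, double[] b)
{
    double dot = 0, na = 0, nb = 0;
    for (int k = 0; k < a.Length; k++)
    {
        dot += a[k] * b[k];
        na += a[k] * a[k];
        nb += b[k] * b[k];
    }
    return dot / (Math.Sqrt(na) * Math.Sqrt(nb) + 1e-8);
}

// Content-based weighting: softmax over beta * cosine(key, row), computed stably.
static double[] ContentWeights(double[] key, double[,] memory, double beta)
{
    int n = memory.GetLength(0), m = memory.GetLength(1);
    var scores = new double[n];
    for (int i = 0; i < n; i++)
    {
        var row = new double[m];
        for (int k = 0; k < m; k++) row[k] = memory[i, k];
        scores[i] = beta * Cosine(key, row);
    }
    double max = double.NegativeInfinity;
    foreach (var s in scores) max = Math.Max(max, s);
    double total = 0;
    var weights = new double[n];
    for (int i = 0; i < n; i++) { weights[i] = Math.Exp(scores[i] - max); total += weights[i]; }
    for (int i = 0; i < n; i++) weights[i] /= total;
    return weights;
}

var memoryRows = new double[,] { { 1, 0 }, { 0, 1 }, { 1, 1 } };
var w = ContentWeights(new double[] { 1, 0 }, memoryRows, beta: 10.0);
Console.WriteLine(string.Join(", ", Array.ConvertAll(w, v => v.ToString("F3"))));
```

Larger β sharpens the weighting toward the best-matching row; smaller β spreads it more evenly.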
GetParameters()
Gets the controller's parameters as a single vector.
public Vector<T> GetParameters()
Returns
Vector<T>: The parameter vector.
Reset()
Resets the controller's internal state, such as the LSTM hidden state carried across timesteps.
public void Reset()
SetParameters(Vector<T>)
Sets the controller's parameters and updates its internal weights accordingly.
public void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The parameter vector to set.
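GetParameters and SetParameters imply that all learnable weights can be flattened into one vector and written back. A minimal sketch of such a round trip, assuming a fixed matrix order and row-major layout (AiDotNet's actual ordering is not documented here); Flatten and Unflatten are hypothetical helpers over plain double[,] arrays:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical (not AiDotNet API): flatten matrices in a fixed order, row-major.
static double[] Flatten(List<double[,]> matrices)
{
    var values = new List<double>();
    foreach (var mat in matrices)
        foreach (var v in mat) values.Add(v); // multidimensional arrays enumerate row-major
    return values.ToArray();
}

// Hypothetical inverse: write a flat vector back into the same matrices, same order.
static void Unflatten(double[] parameters, List<double[,]> matrices)
{
    int expected = 0;
    foreach (var mat in matrices) expected += mat.Length;
    if (parameters.Length != expected)
        throw new ArgumentException($"Expected {expected} parameters, got {parameters.Length}.");
    int offset = 0;
    foreach (var mat in matrices)
        for (int r = 0; r < mat.GetLength(0); r++)
            for (int c = 0; c < mat.GetLength(1); c++)
                mat[r, c] = parameters[offset++];
}

var weights = new List<double[,]> { new double[,] { { 1, 2 }, { 3, 4 } }, new double[,] { { 5 } } };
var flat = Flatten(weights); // 5 values
flat[4] = 50;                // e.g. an external optimizer update
Unflatten(flat, weights);
Console.WriteLine(weights[1][0, 0]);
```

Keeping the order fixed and validating the length on the way back in is what allows an external optimizer to edit the flat vector and hand it back safely.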