Class LSTMNeuralNetwork<T>
- Namespace: AiDotNet.NeuralNetworks
- Assembly: AiDotNet.dll
Represents a Long Short-Term Memory (LSTM) Neural Network, which is specialized for processing sequential data like text, time series, or audio.
public class LSTMNeuralNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations (typically float or double).
- Inheritance: NeuralNetworkBase<T> → LSTMNeuralNetwork<T>
Remarks
Long Short-Term Memory networks are a kind of recurrent neural network (RNN) designed to mitigate the vanishing gradient problem, which makes traditional RNNs struggle to learn dependencies that span many time steps. Each LSTM cell maintains a cell state and uses specialized "gates" (forget, input, and output gates) to regulate the flow of information, which lets the network retain patterns over long sequences and selectively forget irrelevant information.
For Beginners: An LSTM Neural Network is a special type of neural network designed for understanding sequences and patterns that unfold over time.
Think of an LSTM like a smart notepad that can:
- Remember important information for long periods
- Forget irrelevant details
- Update its notes with new information
- Decide what parts of its memory to use for making predictions
For example, when processing a sentence like "The clouds are in the sky", a trained LSTM can:
- Remember "The clouds" as the subject even after seeing several more words
- Understand that "are" should agree with the plural "clouds"
- Predict that "sky" might come after "in the" because clouds are typically in the sky
LSTMs are particularly good at:
- Text generation and language modeling
- Speech recognition
- Time series prediction (like stock prices or weather)
- Translation between languages
- Any task where the order of data matters and patterns may span across long sequences
Constructors
LSTMNeuralNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?, IActivationFunction<T>?)
Creates a new LSTM Neural Network with customizable loss and activation functions, using scalar activation functions.
public LSTMNeuralNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IActivationFunction<T>? outputActivation = null, IActivationFunction<T>? forgetGateActivation = null, IActivationFunction<T>? inputGateActivation = null, IActivationFunction<T>? cellGateActivation = null, IActivationFunction<T>? outputGateActivation = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The architecture configuration that defines how the network is structured.
lossFunction (ILossFunction<T>): The loss function to use for training the network. If null, Mean Squared Error will be used.
outputActivation (IActivationFunction<T>): The scalar activation function to apply to LSTM cell state outputs. If null, a hyperbolic tangent (tanh) activation will be used.
forgetGateActivation (IActivationFunction<T>): The activation function to apply to the forget gate. If null, sigmoid will be used.
inputGateActivation (IActivationFunction<T>): The activation function to apply to the input gate. If null, sigmoid will be used.
cellGateActivation (IActivationFunction<T>): The activation function to apply to the cell gate. If null, tanh will be used.
outputGateActivation (IActivationFunction<T>): The activation function to apply to the output gate. If null, sigmoid will be used.
Remarks
This constructor allows full customization of the LSTM network's activation functions and loss function. Each gate in the LSTM cell can have a different activation function, allowing for experimentation with novel LSTM architectures.
For Beginners: This constructor gives you complete control over your LSTM network.
You can customize:
- The loss function (how the network measures errors)
- The output activation (how cell states are transformed to outputs)
- Each gate's activation function:
  - Forget gate: Controls what information to discard from the cell state
  - Input gate: Controls what new information to store in the cell state
  - Cell gate: Creates candidate values that could be added to the state
  - Output gate: Controls what parts of the cell state to output
This level of customization is useful for advanced users experimenting with different LSTM variants or optimizing for specific tasks.
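To make the role of each constructor argument concrete, here is a minimal sketch of one time step of a single-unit LSTM cell in plain C#, with every activation passed in as a function. It assumes the standard LSTM equations (c = f·c_prev + i·g, h = o·act(c)) and maps each argument to the gate it names; LstmStep, the weights, and the omission of biases are illustrative assumptions, not AiDotNet's internal implementation.

```csharp
using System;

// Defaults named in the constructor docs: sigmoid for the forget, input and
// output gates; tanh for the cell gate and for the output activation.
Func<double, double> sigmoid = z => 1.0 / (1.0 + Math.Exp(-z));
Func<double, double> tanh = Math.Tanh;

// Hypothetical helper: one time step of a single-unit LSTM cell with pluggable
// activations. Weights are arbitrary example values; biases are omitted for brevity.
(double h, double c) LstmStep(double x, double hPrev, double cPrev,
    Func<double, double> forgetAct, Func<double, double> inputAct,
    Func<double, double> cellAct, Func<double, double> outputGateAct,
    Func<double, double> outputAct)
{
    double f = forgetAct(0.5 * x + 0.1 * hPrev);     // forget gate: how much old state to keep
    double i = inputAct(0.6 * x + 0.2 * hPrev);      // input gate: how much new content to write
    double g = cellAct(0.9 * x + 0.3 * hPrev);       // cell gate: candidate values
    double o = outputGateAct(0.7 * x + 0.4 * hPrev); // output gate: how much state to expose
    double c = f * cPrev + i * g;                    // new cell state
    double h = o * outputAct(c);                     // new hidden state (the cell's output)
    return (h, c);
}

// Call with the documented defaults.
var (h1, c1) = LstmStep(1.0, 0.0, 0.0, sigmoid, sigmoid, tanh, sigmoid, tanh);
Console.WriteLine($"h = {h1:F4}, c = {c1:F4}");
```

Swapping one argument, for example passing a forget activation that always returns 1.0, changes only that gate's behaviour, which is the kind of experiment the constructor is meant to allow.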
LSTMNeuralNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?, IVectorActivationFunction<T>?)
Creates a new LSTM Neural Network with customizable loss and activation functions, using vector activation functions.
public LSTMNeuralNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, IVectorActivationFunction<T>? outputVectorActivation = null, IVectorActivationFunction<T>? forgetGateVectorActivation = null, IVectorActivationFunction<T>? inputGateVectorActivation = null, IVectorActivationFunction<T>? cellGateVectorActivation = null, IVectorActivationFunction<T>? outputGateVectorActivation = null)
Parameters
architecture (NeuralNetworkArchitecture<T>): The architecture configuration that defines how the network is structured.
lossFunction (ILossFunction<T>): The loss function to use for training the network. If null, Mean Squared Error will be used.
outputVectorActivation (IVectorActivationFunction<T>): The vector activation function to apply to LSTM cell state outputs. If null, a hyperbolic tangent (tanh) activation will be used.
forgetGateVectorActivation (IVectorActivationFunction<T>): The activation function to apply to the forget gate. If null, sigmoid will be used.
inputGateVectorActivation (IVectorActivationFunction<T>): The activation function to apply to the input gate. If null, sigmoid will be used.
cellGateVectorActivation (IVectorActivationFunction<T>): The activation function to apply to the cell gate. If null, tanh will be used.
outputGateVectorActivation (IVectorActivationFunction<T>): The activation function to apply to the output gate. If null, sigmoid will be used.
Remarks
This constructor allows full customization of the LSTM network's activation functions and loss function, using vector activation functions that receive an entire vector at once rather than a single value. Each gate in the LSTM cell can have a different activation function, allowing for experimentation with novel LSTM architectures.
For Beginners: This constructor gives you complete control over your LSTM network.
You can customize:
- The loss function (how the network measures errors)
- The output activation (how cell states are transformed to outputs)
- Each gate's activation function:
  - Forget gate: Controls what information to discard from the cell state
  - Input gate: Controls what new information to store in the cell state
  - Cell gate: Creates candidate values that could be added to the state
  - Output gate: Controls what parts of the cell state to output
This level of customization is useful for advanced users experimenting with different LSTM variants or optimizing for specific tasks.
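The practical difference between the two constructors is whether an activation sees one value at a time or a whole vector. The sketch below illustrates that distinction with plain arrays, assuming that element-wise functions such as sigmoid fit the scalar interface while functions such as softmax, where every output depends on every input, need the vector form. ApplyScalar and Softmax are hypothetical helpers, not AiDotNet types.

```csharp
using System;
using System.Linq;

// Scalar (element-wise) activation: each output depends only on its own input.
double[] ApplyScalar(double[] v, Func<double, double> f) => v.Select(f).ToArray();

// Vector activation: softmax needs the whole vector, because every output is
// normalized by a sum over all elements.
double[] Softmax(double[] v)
{
    double max = v.Max();                                // subtract the max for numerical stability
    double[] exps = v.Select(z => Math.Exp(z - max)).ToArray();
    double sum = exps.Sum();
    return exps.Select(e => e / sum).ToArray();
}

Func<double, double> sigmoid = z => 1.0 / (1.0 + Math.Exp(-z));
double[] preActivations = { 1.0, 2.0, 3.0 };
double[] sig = ApplyScalar(preActivations, sigmoid);
double[] soft = Softmax(preActivations);
Console.WriteLine(string.Join(", ", sig.Select(s => s.ToString("F3"))));
Console.WriteLine(string.Join(", ", soft.Select(s => s.ToString("F3"))));
```

Changing only the last input leaves the first sigmoid output untouched but changes every softmax output, which is why a vector activation cannot be expressed as a per-element function.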
Methods
CreateNewInstance()
Creates a new instance of the LSTM Neural Network with the same architecture and configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new LSTM Neural Network instance with the same architecture and configuration.
Remarks
This method creates a new instance of the LSTM Neural Network with the same architecture and activation functions as the current instance. It's used in scenarios where a fresh copy of the model is needed while maintaining the same configuration.
For Beginners: This method creates a brand new copy of the LSTM network with the same setup.
Think of it as building a new network from the same blueprint:
- The new network has the same architecture (structure)
- It has the same activation functions for all gates
- It uses the same loss function
- But it's a completely separate instance with its own parameters
This is useful when you want to:
- Create multiple networks with identical settings
- Compare how different initializations affect learning
- Set up ensemble learning with multiple similar networks
DeserializeNetworkSpecificData(BinaryReader)
Deserializes LSTM-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read from.
Remarks
This method loads the state of a previously saved LSTM model from a binary stream. It restores the network-specific parameters, including layer configurations and internal state trackers, allowing the model to continue from exactly where it left off.
For Beginners: This restores a saved LSTM network from a binary stream, typically a file.
When loading the LSTM model:
- All the network's learned parameters are restored
- Layer structure and configuration are restored
- Any internal state information is restored
This lets you:
- Continue working with a model you trained earlier
- Use models that someone else has trained
- Apply a trained model to new data for predictions
GetModelMetadata()
Gets metadata about the LSTM model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the LSTM model.
Remarks
This method returns comprehensive metadata about the LSTM model, including its architecture, layer configuration, and other relevant parameters. This information is useful for model management, tracking experiments, and reporting.
For Beginners: This provides detailed information about your LSTM network.
The metadata includes:
- What this model is designed to do
- Details about the network architecture
- Information about the layers and their sizes
- The total number of parameters (learnable values)
This information is useful for:
- Documentation
- Comparing different LSTM configurations
- Debugging and analysis
- Sharing your model with others
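The parameter total reported in the metadata can be checked by hand. For a common LSTM layout, each of the four gates has an input weight matrix, a recurrent weight matrix, and a bias vector, giving 4 × (hidden × input + hidden × hidden + hidden) parameters per LSTM layer. The sketch below assumes that layout plus one fully connected output layer; AiDotNet's layers may store parameters differently (some libraries keep two bias vectors per gate), and LstmParams and DenseParams are hypothetical helpers.

```csharp
using System;

// Parameters of one LSTM layer, assuming one input weight matrix, one recurrent
// weight matrix and a single bias vector for each of the 4 gates.
int LstmParams(int inputSize, int hiddenSize) =>
    4 * (hiddenSize * inputSize + hiddenSize * hiddenSize + hiddenSize);

// Parameters of a fully connected output layer: weights plus biases.
int DenseParams(int inputSize, int outputSize) => outputSize * inputSize + outputSize;

// Example: 10 input features -> LSTM with 32 units -> 1 output value.
int lstm = LstmParams(10, 32);
int dense = DenseParams(32, 1);
Console.WriteLine($"LSTM: {lstm}, Dense: {dense}, Total: {lstm + dense}");
// prints "LSTM: 5504, Dense: 33, Total: 5537"
```

The recurrent term (hidden × hidden) grows quadratically with the number of units, so widening an LSTM layer increases its parameter count much faster than adding input features does.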
InitializeLayers()
Sets up the layers of the LSTM network based on the provided architecture. If no layers are specified in the architecture, default LSTM layers will be created.
protected override void InitializeLayers()
Remarks
This method initializes the network layers according to the provided architecture. If the architecture includes a specific set of layers, those are used directly. Otherwise, the method creates a default LSTM layer configuration, which typically includes embeddings (for text data), one or more LSTM layers, and appropriate output layers based on the task type.
For Beginners: This method sets up the building blocks of your LSTM network.
When initializing the network:
- If you provided specific layers in the architecture, those are used
- If not, the network creates standard LSTM layers automatically
The standard LSTM setup typically includes:
- Input processing layers (like embedding layers for text)
- One or more LSTM layers that process the sequence
- Output layers that produce the final prediction
This is like assembling all the components of your network before training begins. Each layer has a specific role in processing your sequential data.
Predict(Tensor<T>)
Processes input through the LSTM network to generate predictions.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor to process.
Returns
- Tensor<T>
The output tensor after processing through the LSTM network.
Remarks
This method implements a forward pass through the LSTM network. It handles both single inputs and batched sequences, processing the data through each layer while managing the LSTM's internal state. For sequential data, it processes the input step by step while carrying the hidden state across time steps.
For Beginners: This method processes your data through the LSTM network to make predictions.
The prediction process works like this:
- Data enters the network (like a sequence of words or time series data)
- Each LSTM layer processes the sequence step by step, maintaining internal state
- The network remembers important information while processing the sequence
- Finally, the output layers convert the LSTM's final state into the desired output format
Unlike feedforward neural networks, LSTMs can "remember" information from earlier in the sequence when making predictions later in the sequence, which is crucial for tasks like text understanding or time series forecasting.
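The step-by-step processing described above can be sketched with a single-unit LSTM in plain C#. This is an illustration of the idea, not the library's Predict implementation: the weights are fixed example values, the only bias is a forget-gate bias of 1.0, and RunSequence is a hypothetical helper. The hidden state h and cell state c are carried from each time step into the next.

```csharp
using System;

double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Runs a single-unit LSTM over a sequence, carrying the hidden state h and the
// cell state c from one time step to the next. Weights are illustrative.
double RunSequence(double[] xs)
{
    double h = 0.0, c = 0.0;                           // state starts empty
    foreach (double x in xs)
    {
        double f = Sigmoid(0.5 * x + 0.1 * h + 1.0);   // forget gate (bias 1.0 favours keeping state)
        double i = Sigmoid(0.6 * x + 0.2 * h);         // input gate
        double g = Math.Tanh(0.9 * x + 0.3 * h);       // candidate values
        double o = Sigmoid(0.7 * x + 0.4 * h);         // output gate
        c = f * c + i * g;                             // update long-term memory
        h = o * Math.Tanh(c);                          // hidden state for this step
    }
    return h;                                          // final hidden state summarizes the sequence
}

// The two sequences differ only in their first element.
double a = RunSequence(new[] { 1.0, 0.0, 0.0, 0.0 });
double b = RunSequence(new[] { -1.0, 0.0, 0.0, 0.0 });
Console.WriteLine($"{a:F4} vs {b:F4}");
```

The final hidden states have opposite signs even though the last three inputs are identical, showing information from the first step surviving to the end of the sequence.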
SerializeNetworkSpecificData(BinaryWriter)
Serializes LSTM-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write to.
Remarks
This method saves the state of the LSTM model to a binary stream. It serializes network-specific parameters, including layer configurations and internal state trackers. This allows the complete state to be restored later for continued training or inference.
For Beginners: This saves your LSTM network to a binary stream, typically a file.
When saving the LSTM model:
- All the network's learned parameters are saved
- Layer structure and configuration are saved
- Any internal state information is saved
This allows you to:
- Save your progress after training
- Share trained models with others
- Load the model later for additional training or making predictions
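SerializeNetworkSpecificData and DeserializeNetworkSpecificData only work as a pair if fields are read back in exactly the order they were written. The sketch below shows that pattern with BinaryWriter and BinaryReader over a MemoryStream. The layout (layer count, then each layer's parameter count and values) and the Save and Load helpers are hypothetical, chosen for illustration; this is not AiDotNet's actual serialization format.

```csharp
using System;
using System.IO;

// Hypothetical per-layer parameters to persist.
double[][] layers = { new[] { 0.1, -0.2, 0.3 }, new[] { 1.5, 2.5 } };

byte[] Save(double[][] ls)
{
    var ms = new MemoryStream();
    using (var writer = new BinaryWriter(ms))
    {
        writer.Write(ls.Length);                   // number of layers
        foreach (double[] layer in ls)
        {
            writer.Write(layer.Length);            // parameter count for this layer
            foreach (double p in layer) writer.Write(p);
        }
    }                                              // disposing the writer flushes it
    return ms.ToArray();                           // ToArray still works on a closed MemoryStream
}

double[][] Load(byte[] data)
{
    using var reader = new BinaryReader(new MemoryStream(data));
    int count = reader.ReadInt32();                // read fields in the order they were written
    var result = new double[count][];
    for (int l = 0; l < count; l++)
    {
        int n = reader.ReadInt32();
        result[l] = new double[n];
        for (int k = 0; k < n; k++) result[l][k] = reader.ReadDouble();
    }
    return result;
}

double[][] restored = Load(Save(layers));
Console.WriteLine($"{restored.Length} layers restored");
```

Writing counts before values lets the reader allocate arrays of the right size without knowing the network's shape in advance.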
Train(Tensor<T>, Tensor<T>)
Trains the LSTM network on input-output pairs.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input tensor for training.
expectedOutput (Tensor<T>): The expected output tensor.
Remarks
This method trains the LSTM network using backpropagation through time (BPTT). It performs a forward pass to get predictions, calculates the error, and backpropagates the gradients through the network over time to update the weights.
For Beginners: This method teaches the LSTM network to make accurate predictions.
The training process works like this:
- Input data (like sequences of words or time steps) is fed into the network
- The network makes predictions at each time step
- These predictions are compared with the expected outputs to calculate the error
- The error is "backpropagated" through time to work out how much each weight contributed to it, and the weights are adjusted accordingly
- This process repeats for many examples, gradually improving the network's performance
The key difference from training feedforward networks is that LSTM training must account for connections across time steps, because earlier inputs influence later outputs.
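The core idea of backpropagation through time can be shown without the full LSTM gate equations, which would make the backward pass too long for a short example. The sketch below swaps in a one-weight linear recurrent network, h_t = w·h_(t-1) + x_t with loss ½(h_T − y)², and computes dL/dw by walking the time steps in reverse, accumulating the shared weight's contribution at every step. It then checks the result against a finite-difference estimate and takes one gradient-descent step. Forward, Loss, and GradW are hypothetical helpers, not AiDotNet methods.

```csharp
using System;

// Forward pass: h_t = w * h_{t-1} + x_t, starting from h_0 = 0.
// All hidden states are kept so the backward pass can reuse them.
double[] Forward(double w, double[] xs)
{
    var hs = new double[xs.Length + 1];        // hs[0] is the initial state
    for (int t = 0; t < xs.Length; t++)
        hs[t + 1] = w * hs[t] + xs[t];
    return hs;
}

double Loss(double w, double[] xs, double y)
{
    double diff = Forward(w, xs)[^1] - y;
    return 0.5 * diff * diff;
}

// Backward pass through time: visit the steps in reverse order and add up
// the gradient contribution of the shared weight w at each step.
double GradW(double w, double[] xs, double y)
{
    double[] hs = Forward(w, xs);
    double dh = hs[^1] - y;                    // dL/dh_T
    double grad = 0.0;
    for (int t = xs.Length; t >= 1; t--)
    {
        grad += dh * hs[t - 1];                // local effect of w on h_t is h_{t-1}
        dh *= w;                               // propagate the gradient to h_{t-1}
    }
    return grad;
}

double[] seq = { 1.0, 0.5, -0.3 };
double w0 = 0.8, target = 1.2;
double analytic = GradW(w0, seq, target);
double eps = 1e-6;
double numeric = (Loss(w0 + eps, seq, target) - Loss(w0 - eps, seq, target)) / (2 * eps);
Console.WriteLine($"BPTT gradient {analytic:F6}, finite difference {numeric:F6}");

double w1 = w0 - 0.1 * analytic;               // one gradient-descent step (learning rate 0.1)
Console.WriteLine($"loss before {Loss(w0, seq, target):F6}, after {Loss(w1, seq, target):F6}");
```

An LSTM applies the same reverse-time accumulation, but each step's local derivatives pass through the gates and the cell state instead of a single multiplication by w.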
UpdateParameters(Vector<T>)
Updates the internal parameters (weights and biases) of the network with new values. This is typically used to apply parameter values computed by an optimization algorithm.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing all parameters for all layers in the network, concatenated layer by layer in the same order as the network's layers.
Remarks
This method distributes a vector of parameters to the appropriate layers in the network. It determines how many parameters each layer needs, extracts the corresponding segment from the input parameter vector, and updates each layer with its respective parameters. This is commonly used after optimization algorithms have calculated improved weights for the network.
For Beginners: This method updates all the learned values in the network.
During training, an LSTM network learns many values (called parameters) that determine how it processes information. These include:
- Weights that control how inputs affect the network
- Gate parameters that control what information to remember or forget
- Output parameters that determine how predictions are made
This method:
- Takes a long list of all these parameters
- Figures out which parameters belong to which layers
- Updates each layer with its corresponding parameters
Think of it like updating the settings on different parts of a machine based on what it has learned through experience.
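The distribution step can be sketched with plain arrays: walk the layers in order, take each layer's slice of the flat vector, and advance an offset. The layer sizes, the Distribute helper, and the length check are assumptions for illustration; how UpdateParameters validates its input is not stated in this documentation.

```csharp
using System;
using System.Linq;

// Hypothetical parameter counts for three layers, in network order.
int[] layerParamCounts = { 6, 4, 3 };

// Split one flat parameter vector into per-layer segments, in layer order.
double[][] Distribute(double[] parameters, int[] counts)
{
    int expected = counts.Sum();
    if (parameters.Length != expected)             // assumed check: lengths must match exactly
        throw new ArgumentException($"Expected {expected} parameters, got {parameters.Length}.");

    var segments = new double[counts.Length][];
    int offset = 0;
    for (int l = 0; l < counts.Length; l++)
    {
        segments[l] = parameters.Skip(offset).Take(counts[l]).ToArray(); // this layer's slice
        offset += counts[l];                                             // move to the next layer
    }
    return segments;
}

double[] flat = Enumerable.Range(0, 13).Select(k => (double)k).ToArray();
double[][] perLayer = Distribute(flat, layerParamCounts);
for (int l = 0; l < perLayer.Length; l++)
    Console.WriteLine($"layer {l}: {perLayer[l].Length} values, starting at {perLayer[l][0]}");
```

Because each layer reads its slice at a running offset, the flat vector must list parameters in exactly the network's layer order; a vector assembled in a different order would load without error yet put weights in the wrong layers.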