Class DeepQNetwork<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents a Deep Q-Network (DQN), a reinforcement learning algorithm that combines Q-learning with deep neural networks.
public class DeepQNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- NeuralNetworkBase<T> → DeepQNetwork<T>
Remarks
A Deep Q-Network (DQN) is a reinforcement learning algorithm that uses a neural network to approximate the Q-function, which represents the expected future rewards for taking specific actions in specific states. DQNs overcome the limitations of traditional Q-learning by using neural networks to generalize across states, allowing them to handle problems with large or continuous state spaces. Key features of DQNs include experience replay (storing and randomly sampling past experiences) and the use of a separate target network to stabilize learning.
For Beginners: A Deep Q-Network is like a smart decision-maker that learns through trial and error.
Imagine you're teaching a robot to play a video game:
- The robot needs to learn which actions (button presses) are best in each situation (game screen)
- At first, the robot makes many random moves to explore the game
- Over time, it remembers which moves led to high scores and which led to game over
- The "Deep" part means it uses a neural network to recognize patterns in complex situations
- The "Q" refers to "Quality" - how good an action is in a specific situation
For example, in a maze game, the network learns that moving toward the exit is usually better than moving away from it, even if it hasn't seen that exact maze position before.
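The methods documented below are typically wired together in a loop like the sketch that follows. Only the DeepQNetwork<T> calls (GetAction, AddExperience, Train) come from this page; the environment object, its Reset/Step members, and the ToStateTensor helper are illustrative assumptions.

// Hedged usage sketch: `env`, its Reset/Step members, and ToStateTensor are hypothetical;
// the DeepQNetwork<double> calls are the API documented on this page.
var dqn = new DeepQNetwork<double>(architecture);            // architecture configured elsewhere

for (int episode = 0; episode < 500; episode++)
{
    Tensor<double> state = ToStateTensor(env.Reset());       // convert the observation to a tensor
    bool done = false;

    while (!done)
    {
        int action = dqn.GetAction(state);                   // epsilon-greedy action selection
        var step = env.Step(action);                         // hypothetical (Observation, Reward, Done) result
        Tensor<double> nextState = ToStateTensor(step.Observation);

        dqn.AddExperience(state, action, step.Reward, nextState, step.Done);
        dqn.Train(state, nextState);                         // arguments are unused; Train samples the replay buffer

        state = nextState;
        done = step.Done;
    }
}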
Constructors
DeepQNetwork(NeuralNetworkArchitecture<T>, ILossFunction<T>?, double)
Initializes a new instance of the DeepQNetwork<T> class with the specified architecture and exploration rate.
public DeepQNetwork(NeuralNetworkArchitecture<T> architecture, ILossFunction<T>? lossFunction = null, double epsilon = 1)
Parameters
architecture (NeuralNetworkArchitecture<T>): The neural network architecture configuration.
lossFunction (ILossFunction<T>): The loss function used during training. Optional; defaults to null.
epsilon (double): The initial exploration rate (probability of taking random actions). Default is 1.0 for full exploration.
Remarks
This constructor creates a new Deep Q-Network with the specified architecture and exploration rate. It also initializes a separate target network with the same architecture, which is used to generate target Q-values during training.
For Beginners: This sets up a new Deep Q-Network ready to start learning.
When creating a new DQN:
- You provide an "architecture" that defines the neural network's structure
- You can set an "epsilon" value that controls how often it will try random actions
- The constructor also creates a copy of the network (target network) that helps with stable learning
Think of it like setting up a new student with blank notebooks (the neural networks) and a curiosity level (epsilon) that determines how often they'll experiment versus stick with what they know.
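A minimal construction sketch, assuming a NeuralNetworkArchitecture<double> instance has already been configured elsewhere; the argument values shown are the documented defaults.

// `architecture` is assumed to be a configured NeuralNetworkArchitecture<double>.
var dqn = new DeepQNetwork<double>(
    architecture,
    lossFunction: null,   // optional; null is the default
    epsilon: 1.0);        // start with full exploration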
Methods
AddExperience(Tensor<T>, int, T, Tensor<T>, bool)
Adds a new experience to the replay buffer.
public void AddExperience(Tensor<T> state, int action, T reward, Tensor<T> nextState, bool done)
Parameters
state (Tensor<T>): The state before taking the action.
action (int): The action taken.
reward (T): The reward received after taking the action.
nextState (Tensor<T>): The state after taking the action.
done (bool): A flag indicating whether the episode ended after this action.
Remarks
This method adds a new experience tuple (state, action, reward, next state, done) to the replay buffer for use in experience replay during training. If the buffer exceeds its maximum size (10,000 experiences), the oldest experiences are removed to make room for new ones.
For Beginners: This method stores a new experience in the agent's memory.
Each experience includes:
- State: What the environment looked like before the action
- Action: What the agent decided to do
- Reward: The immediate feedback received (positive or negative)
- Next State: What the environment looked like after the action
- Done: Whether this action ended the episode (like finishing a game level)
The agent keeps a limited memory of past experiences (10,000 in this case) and removes the oldest ones when necessary, like a scrolling journal that only keeps the most recent entries.
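A short hedged sketch of recording the final transition of an episode; the state tensors, action index, and reward value are assumed to come from an environment step like the one in the earlier usage sketch.

// Store a terminal transition: `done: true` marks the end of the episode,
// so training will use the immediate reward alone as the target for this step.
dqn.AddExperience(state, action, reward, nextState, done: true);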
CreateNewInstance()
Creates a new instance of the Deep Q-Network with the same architecture and configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new Deep Q-Network instance with the same architecture and configuration.
Remarks
This method creates a new instance of the Deep Q-Network with the same architecture and exploration rate (epsilon) as the current instance. It's used in scenarios where a fresh copy of the model is needed while maintaining the same configuration.
For Beginners: This method creates a brand new copy of the agent with the same setup.
Think of it like creating a clone of the agent:
- The new agent has the same neural network architecture
- The new agent has the same exploration rate (epsilon)
- But it's a completely separate instance with its own memory and learning state
This is useful when you need multiple instances of the same DQN model, such as for parallel training or comparing different learning strategies.
DeserializeNetworkSpecificData(BinaryReader)
Loads Deep Q-Network specific data from a binary stream.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to load from.
Remarks
This method deserializes DQN-specific data that was previously saved using SerializeNetworkSpecificData. It restores the exploration rate, action space size, and target network state.
For Beginners: This method loads special DQN settings from a file.
When loading a saved model:
- The base neural network parts are loaded by other methods
- This method loads the DQN-specific settings
- It restores values like the exploration rate
- It also loads the target network that helps stabilize learning
This allows you to continue using a previously trained DQN with all its settings intact.
GetAction(Tensor<T>)
Gets an action to take in the given state, balancing exploration and exploitation.
public int GetAction(Tensor<T> state)
Parameters
state (Tensor<T>): The input tensor representing the current state.
Returns
- int
The index of the selected action.
Remarks
This method implements the epsilon-greedy policy for action selection. With probability epsilon, it selects a random action (exploration), and with probability 1-epsilon, it selects the action with the highest Q-value (exploitation). This balance between exploration and exploitation is crucial for effective reinforcement learning.
For Beginners: This method decides whether to try a random action or use what the agent has learned.
The process works like this:
- The agent generates a random number between 0 and 1
- If the number is less than epsilon, it takes a completely random action (exploration)
- Otherwise, it takes the action it currently thinks is best (exploitation)
This balance is important because:
- If the agent only exploits, it might miss better strategies it hasn't discovered yet
- If the agent only explores, it never uses what it has learned
For example, with epsilon = 0.1:
- 10% of the time, the agent will try a random action
- 90% of the time, it will choose the action with the highest Q-value
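The epsilon-greedy rule can be written out in a few lines. The sketch below is an illustration of the policy itself (GetAction already applies it internally), not the class's implementation, and the Random instance and actionCount value are assumptions.

// Illustration of epsilon-greedy selection; GetAction performs this internally.
var rng = new Random();

int SelectAction(DeepQNetwork<double> agent, Tensor<double> state, double epsilon, int actionCount)
{
    if (rng.NextDouble() < epsilon)
        return rng.Next(actionCount);      // explore: pick a uniformly random action
    return agent.GetBestAction(state);     // exploit: pick the highest-Q action
}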
GetBestAction(Tensor<T>)
Gets the best action to take in the given state based on current Q-values.
public int GetBestAction(Tensor<T> state)
Parameters
state (Tensor<T>): The input tensor representing the current state.
Returns
- int
The index of the action with the highest Q-value.
Remarks
This method determines the optimal action to take in the given state by selecting the action with the highest predicted Q-value. This represents the action that the network currently believes will lead to the highest expected future reward.
For Beginners: This method finds the action that the agent thinks is best in the current situation.
The process works like this:
- The agent gets Q-values (scores) for each possible action
- It then selects the action with the highest score
- This is called "exploitation" - using what the agent has learned so far
For example, if moving left has a Q-value of 5 and moving right has a Q-value of 10, the agent will choose to move right because it has the higher expected reward.
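After training, an agent is usually run greedily with GetBestAction rather than GetAction. A hedged rollout sketch, reusing the hypothetical environment helpers from the earlier usage example:

// Greedy evaluation rollout: always take the action with the highest Q-value.
Tensor<double> state = ToStateTensor(env.Reset());
bool done = false;
while (!done)
{
    int action = dqn.GetBestAction(state);   // pure exploitation, no random moves
    var step = env.Step(action);             // hypothetical environment result
    state = ToStateTensor(step.Observation);
    done = step.Done;
}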
GetModelMetadata()
Gets metadata about this Deep Q-Network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
Remarks
This method returns metadata about the model, including its name, description, architecture, and other relevant information that might be useful for users or tools working with the model.
For Beginners: This method provides information about this neural network model.
The metadata includes:
- The type of model (Deep Q-Network)
- The network architecture (how many layers, neurons, etc.)
- The action space (how many different actions the agent can take)
- Other settings like the exploration rate
This information is useful for documentation, debugging, and when saving/loading models.
GetQValues(Tensor<T>)
Gets the Q-values for all possible actions in the given state.
public Tensor<T> GetQValues(Tensor<T> state)
Parameters
state (Tensor<T>): The input tensor representing the current state.
Returns
- Tensor<T>
A tensor of Q-values, one for each possible action.
Remarks
This method is a wrapper around the Predict method that makes it more semantically clear that the output represents Q-values for each possible action in the given state.
For Beginners: This method tells you how good the agent thinks each action is in the current situation.
The Q-values represent:
- The agent's estimate of how much future reward it will get
- If it takes a specific action in the current state
- And then continues to act optimally afterward
For example, in a game, a Q-value of 100 for "move right" means the agent expects that moving right now will eventually lead to a total reward of about 100 points.
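A small sketch of inspecting the per-action estimates. Indexing the returned tensor with an integer indexer is an assumption, since the Tensor<T> accessors are not documented on this page.

// Print the network's estimated return for each action in the current state.
Tensor<double> qValues = dqn.GetQValues(state);
for (int a = 0; a < actionCount; a++)                    // actionCount assumed known
{
    Console.WriteLine($"Action {a}: Q = {qValues[a]}");  // indexer usage is an assumption
}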
InitializeLayers()
Initializes the layers of the Deep Q-Network based on the architecture.
protected override void InitializeLayers()
Remarks
This method sets up the layers of the Deep Q-Network. If custom layers are provided in the architecture, those layers are used. Otherwise, default layers are created based on the architecture's specifications. After setting up the layers, the method sets the action space based on the output size of the network.
For Beginners: This method builds the actual structure of the neural network.
When initializing the layers:
- If you've specified your own custom layers, the network will use those
- If not, the network will create a standard set of layers suitable for a DQN
- The method also determines how many different actions the agent can take
This is like assembling the brain of our agent, which will learn to make decisions by associating situations with the best actions to take.
Predict(Tensor<T>)
Performs a forward pass with a tensor input.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor representing the current state.
Returns
- Tensor<T>
A tensor of Q-values for each possible action.
Remarks
This overload of the Predict method handles tensor inputs directly. It processes the input through all layers of the network and returns a tensor of Q-values.
For Beginners: This method does the same thing as the vector-based Predict, but works with tensors (multi-dimensional arrays) directly, which can be more efficient for certain types of inputs like images.
SerializeNetworkSpecificData(BinaryWriter)
Saves Deep Q-Network specific data to a binary stream.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to save to.
Remarks
This method serializes DQN-specific data that isn't part of the base neural network. This includes the exploration rate (epsilon), action space size, and target network state.
For Beginners: This method saves special DQN settings to a file.
When saving the model:
- The base neural network parts are saved by other methods
- This method saves the DQN-specific settings
- This includes values like how often the agent explores
- It also saves the target network that helps stabilize learning
This allows you to load the exact same DQN later, with all its settings intact.
Train(Tensor<T>, Tensor<T>)
Trains the network using experience replay.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): Not used in DQN; experiences are sampled from the replay buffer.
expectedOutput (Tensor<T>): Not used in DQN; target Q-values are computed internally.
Remarks
This method implements the core DQN training algorithm using experience replay. It samples a batch of experiences from the replay buffer, computes target Q-values using the target network, and updates the main network to minimize the difference between predicted and target Q-values.
For Beginners: This method is how the agent learns from its past experiences.
The training process works like this:
- Randomly select a batch of experiences from memory
- For each experience, calculate what the Q-values should have been:
- For terminal states (game over), the target is just the immediate reward
- For non-terminal states, the target is the immediate reward plus the discounted maximum future Q-value
- Update the network to better predict these target values
- Periodically update the target network to match the main network
This helps the agent gradually improve its ability to predict which actions will lead to higher rewards.
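The per-experience target described in the list above is the standard DQN target. The sketch below writes it out explicitly; the discount factor gamma and the max-over-next-actions value appear as inputs here only for illustration, since the class computes them internally and does not expose them through this API.

// Illustration of the DQN training target (not the class's internal code):
//   target = reward                                   when the episode ended (done)
//   target = reward + gamma * max_a' Qtarget(s', a')  otherwise
double ComputeTarget(double reward, bool done, double maxNextQ, double gamma = 0.99)
{
    return done ? reward : reward + gamma * maxNextQ;
}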
UpdateParameters(Vector<T>)
Updates the parameters of all layers in the Deep Q-Network.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing the parameters to update all layers with.
Remarks
This method distributes the provided parameter vector among all the layers in the network. Each layer receives a portion of the parameter vector corresponding to its number of parameters. The method keeps track of the starting index for each layer's parameters in the input vector.
For Beginners: This method updates all the internal values of the neural network at once.
When updating parameters:
- The input is a long list of numbers representing all values in the entire network
- The method divides this list into smaller chunks
- Each layer gets its own chunk of values
- The layers use these values to adjust their internal settings
This method is less commonly used in DQN than the standard training process, but it provides a way to directly set all parameters at once, which can be useful in certain scenarios like loading pretrained weights or implementing advanced optimization algorithms.
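A conceptual sketch of the chunk-by-chunk distribution described above, written against hypothetical layer members (ParameterCount, SetParameters) and a hypothetical Slice helper purely to illustrate the indexing; the actual layer API may differ.

// Conceptual illustration of splitting one flat parameter vector across layers.
int startIndex = 0;
foreach (var layer in layers)                          // `layers` and its members are hypothetical
{
    int count = layer.ParameterCount;                  // number of values this layer owns
    var chunk = parameters.Slice(startIndex, count);   // hypothetical slicing helper
    layer.SetParameters(chunk);                        // hand the layer its chunk
    startIndex += count;                               // advance to the next layer's region
}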