Class DuelingDQNAgent<T>
- Namespace
- AiDotNet.ReinforcementLearning.Agents.DuelingDQN
- Assembly
- AiDotNet.dll
Dueling Deep Q-Network agent for reinforcement learning.
public class DuelingDQNAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations.
- Inheritance
- DeepReinforcementLearningAgentBase<T> → DuelingDQNAgent<T>
- Implements
- IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Remarks
Dueling DQN separates the estimation of state value V(s) and action advantages A(s,a), allowing the network to learn which states are valuable without having to learn the effect of each action for each state. This architecture is particularly effective when many actions do not affect the state in a relevant way.
For Beginners: Dueling DQN splits Q-values into two parts:
- **Value V(s)**: How good is this state overall?
- **Advantage A(s,a)**: How much better is action 'a' than the average action in this state?

The two are recombined as **Q(s,a) = V(s) + (A(s,a) - mean(A(s,:)))**; subtracting the mean advantage keeps the value and advantage estimates identifiable.
This is powerful because:
- The agent learns state values even when actions don't matter much
- Faster learning in scenarios where action choice rarely matters
- Better generalization across similar states
Example: In a car driving game, being on the road is valuable regardless of whether you accelerate slightly or not. Dueling DQN learns "being on road = good" separately from "how much to accelerate".
Reference: Wang et al., "Dueling Network Architectures for Deep Reinforcement Learning", 2016.
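The mean-subtracted aggregation can be illustrated in isolation. Below is a minimal sketch using plain double arrays rather than the library's network types; the method name CombineStreams is illustrative only.

```csharp
// Minimal sketch of the dueling aggregation with plain double arrays;
// CombineStreams is an illustrative name, not a library method.
static double[] CombineStreams(double stateValue, double[] advantages)
{
    // mean(A(s,:)) over all actions
    double meanAdvantage = 0.0;
    foreach (double a in advantages) meanAdvantage += a;
    meanAdvantage /= advantages.Length;

    // Q(s,a) = V(s) + (A(s,a) - mean(A(s,:)))
    var qValues = new double[advantages.Length];
    for (int i = 0; i < advantages.Length; i++)
        qValues[i] = stateValue + (advantages[i] - meanAdvantage);
    return qValues;
}
```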
Constructors
DuelingDQNAgent(DuelingDQNOptions<T>)
public DuelingDQNAgent(DuelingDQNOptions<T> options)
Parameters
options (DuelingDQNOptions<T>): Configuration options for the Dueling DQN agent.
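A construction sketch, assuming T = double. The members of DuelingDQNOptions<T> are not documented on this page, so the initializer is left empty:

```csharp
// Construction sketch; configure state/action dimensions and hyperparameters
// on the options object as your build of the library exposes them.
var options = new DuelingDQNOptions<double>();
var agent = new DuelingDQNAgent<double>(options);
```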
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
- int
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients (Vector<T>): The gradients to apply.
learningRate (T): The learning rate for the update.
Clone()
Clones the agent.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
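A cloning sketch, assuming T = double; whether the copy shares any internal buffers with the original is implementation-defined:

```csharp
// Clone sketch; the copy is returned through the IFullModel interface.
IFullModel<double, Vector<double>, Vector<double>> copy = agent.Clone();
```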
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input (Vector<T>): The input state vector.
target (Vector<T>): The target output vector.
lossFunction (ILossFunction<T>?): Optional loss function used when computing gradients.
Returns
- Vector<T>
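A sketch of a manual gradient step combining ComputeGradients with ApplyGradients above; it assumes agent, input, and target already exist and that T = double:

```csharp
// Manual gradient step sketch; `input` and `target` are pre-existing vectors.
Vector<double> grads = agent.ComputeGradients(input, target, lossFunction: null);
agent.ApplyGradients(grads, learningRate: 0.001);
```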
Deserialize(byte[])
Deserializes the agent from bytes.
public override void Deserialize(byte[] data)
Parameters
data (byte[]): The serialized agent bytes.
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
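A sketch of reading the metrics dictionary; the metric names it contains are implementation-defined and not listed on this page:

```csharp
// Metrics sketch; keys depend on the agent's implementation.
foreach (var kvp in agent.GetMetrics())
    Console.WriteLine($"{kvp.Key}: {kvp.Value}");
```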
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads the agent's state from a file.
public override void LoadModel(string filepath)
Parameters
filepath (string): Path to load the agent from.
SaveModel(string)
Saves the agent's state to a file.
public override void SaveModel(string filepath)
Parameters
filepath (string): Path to save the agent.
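A sketch of a save/load round trip using SaveModel together with LoadModel above, assuming T = double and an options instance matching the saved agent:

```csharp
// Persistence sketch: save a trained agent, then restore it into a fresh one.
agent.SaveModel("dueling_dqn.bin");

var restored = new DuelingDQNAgent<double>(options);
restored.LoadModel("dueling_dqn.bin");
```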
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> state, bool training = true)
Parameters
state (Vector<T>): The current state observation as a Vector.
training (bool): Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
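A sketch of evaluation-time action selection, assuming T = double; with training: false, exploration is disabled per the parameter description:

```csharp
// Evaluation sketch; training: false disables exploration.
Vector<double> action = agent.SelectAction(state, training: false);
```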
Serialize()
Serializes the agent to bytes.
public override byte[] Serialize()
Returns
- byte[]
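A sketch of an in-memory round trip using Serialize with Deserialize above, e.g. to transfer a policy without touching the file system; it assumes an options instance matching the serialized agent:

```csharp
// In-memory round trip sketch.
byte[] blob = agent.Serialize();
var copy = new DuelingDQNAgent<double>(options);
copy.Deserialize(blob);
```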
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The parameter vector to set.
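A sketch of syncing weights between two agents via GetParameters and SetParameters, assuming source and target were built from the same options so their parameter vectors line up:

```csharp
// Weight-sync sketch between two identically configured agents.
Vector<double> weights = source.GetParameters();
target.SetParameters(weights);
```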
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
Parameters
state (Vector<T>): The state before the action.
action (Vector<T>): The action taken.
reward (T): The reward received.
nextState (Vector<T>): The state after the action.
done (bool): Whether the episode terminated.
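A sketch of recording one transition; env is a hypothetical environment whose Step and Reset methods are assumptions, not part of this library:

```csharp
// Hypothetical environment sketch: env.Step and env.Reset are assumptions.
var action = agent.SelectAction(state, training: true);
var (nextState, reward, done) = env.Step(action);  // assumed signature
agent.StoreExperience(state, action, reward, nextState, done);
state = done ? env.Reset() : nextState;
```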
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.
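A sketch of interleaving collection and training, assuming T = double and a replay buffer that has been filled via StoreExperience:

```csharp
// Training loop sketch; assumes transitions are being collected elsewhere
// (see the StoreExperience example above).
for (int step = 0; step < 10_000; step++)
{
    double loss = agent.Train();  // one gradient update; returns the loss
    if (step % 1_000 == 0)
        Console.WriteLine($"step {step}, loss {loss}");
}
```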