Class DeepReinforcementLearningAgentBase<T>
- Namespace
- AiDotNet.ReinforcementLearning.Agents
- Assembly
- AiDotNet.dll
Base class for deep reinforcement learning agents that use neural networks as function approximators.
public abstract class DeepReinforcementLearningAgentBase<T> : ReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations (typically float or double).
- Inheritance
- ReinforcementLearningAgentBase<T> → DeepReinforcementLearningAgentBase<T>
- Implements
- IRLAgent<T>
Remarks
This class extends ReinforcementLearningAgentBase to provide specific support for neural network-based RL algorithms. It manages neural network instances and provides infrastructure for deep RL methods.
For Beginners: This is the base class for modern "deep" RL agents.
Deep RL uses neural networks to approximate the policy and/or value functions, enabling agents to handle high-dimensional state spaces (like images) and complex decision problems.
Classical RL methods (tabular Q-learning, linear approximation) inherit directly from ReinforcementLearningAgentBase, while deep RL methods (DQN, PPO, A3C, etc.) inherit from this class, which adds neural network support.
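For example, a custom agent might subclass it like this (a minimal sketch; the class name, constructor shape, and the elided abstract members are illustrative assumptions, not part of the library):
// Hypothetical skeleton of a custom deep RL agent. Only members documented
// on this page are shown; the base class's remaining abstract members are
// omitted for brevity.
public class MyDqnAgent<T> : DeepReinforcementLearningAgentBase<T>
{
    public MyDqnAgent(ReinforcementLearningOptions<T> options, INeuralNetwork<T> qNetwork)
        : base(options)
    {
        // Register the network so base-class features (ParameterCount,
        // Dispose, JIT export) can see it.
        Networks.Add(qNetwork);
    }

    // ... action selection, training update, and other abstract members
    //     omitted for brevity ...
}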
Examples of deep RL algorithms:
- DQN family (DQN, Double DQN, Rainbow)
- Policy gradient methods (PPO, TRPO, A3C)
- Actor-Critic methods (SAC, TD3, DDPG)
- Model-based methods (Dreamer, MuZero, World Models)
- Transformer-based methods (Decision Transformer)
JIT Compilation Support: Deep RL agents support JIT compilation for policy inference when their underlying neural networks implement IJitCompilable. The JIT-compiled policy network provides fast, deterministic action selection (without exploration), suitable for deployment.
Constructors
DeepReinforcementLearningAgentBase(ReinforcementLearningOptions<T>)
Initializes a new instance of the DeepReinforcementLearningAgentBase class.
protected DeepReinforcementLearningAgentBase(ReinforcementLearningOptions<T> options)
Parameters
options (ReinforcementLearningOptions<T>): Configuration options for the agent.
Fields
Networks
The neural network(s) used by this agent for function approximation.
protected List<INeuralNetwork<T>> Networks
Field Value
- List<INeuralNetwork<T>>
Remarks
Deep RL agents typically use one or more neural networks:
- Value-based: Q-network (and possibly a target network)
- Policy-based: Policy network
- Actor-Critic: Separate policy and value networks
- Model-based: Dynamics model, reward model, etc.
For Beginners: Neural networks are the "brains" of deep RL agents. They learn to map states to:
- Action values (Q-networks in DQN)
- Action probabilities (Policy networks in PPO)
- State values (Value networks in A3C)
- Or combinations of these
This list holds all the networks this agent uses. For example:
- DQN: 1-2 networks (Q-network, optional target network)
- A3C: 2 networks (policy network, value network)
- SAC: 4+ networks (policy, two Q-networks, two target Q-networks)
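As a concrete illustration, a SAC-style agent could register its networks in the constructor (a sketch; the parameter names and ordering convention are assumptions, not a library requirement):
// Hypothetical SAC-style constructor: register all five networks so the
// base class can track, count, and dispose them.
protected MySacAgent(ReinforcementLearningOptions<T> options,
    INeuralNetwork<T> policy,
    INeuralNetwork<T> q1, INeuralNetwork<T> q2,
    INeuralNetwork<T> targetQ1, INeuralNetwork<T> targetQ2)
    : base(options)
{
    // Local convention: policy first, then critics, then their targets.
    Networks.Add(policy);
    Networks.Add(q1);
    Networks.Add(q2);
    Networks.Add(targetQ1);
    Networks.Add(targetQ2);
}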
Properties
ParameterCount
Gets the total number of trainable parameters across all networks.
public override int ParameterCount { get; }
Property Value
- int
Remarks
This sums the parameter counts from all neural networks used by the agent. Useful for monitoring model complexity and memory requirements.
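For example, you might log it after construction as a quick sanity check (agent stands in for any concrete deep RL agent):
// A sudden jump in parameter count usually means an unintentionally
// large network was configured.
Console.WriteLine($"Trainable parameters: {agent.ParameterCount:N0}");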
SupportsJitCompilation
Gets whether this deep RL agent supports JIT compilation.
public override bool SupportsJitCompilation { get; }
Property Value
- bool
true if the policy network supports JIT compilation; false otherwise.
Remarks
Deep RL agents support JIT compilation when their policy network (the network used for action selection) implements IJitCompilable and reports SupportsJitCompilation = true.
JIT Compilation for RL Inference: When JIT compilation is supported, you can export the policy network's computation graph for optimized inference. This is particularly useful for:
- Deployment in production environments where inference speed matters
- Running agents on embedded devices or edge hardware
- Reducing latency in real-time control applications
Important: JIT compilation exports the deterministic policy (without exploration). This is appropriate for deployment but not for training where exploration is needed.
Methods
Dispose()
Disposes of resources used by the agent, including neural networks.
public override void Dispose()
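Because the agent implements IDisposable and owns its networks, dispose it when you are done (a sketch; CreateAgent is a hypothetical factory method):
// Dispose releases the agent's neural network resources.
using var agent = CreateAgent();   // hypothetical factory method
// ... train and evaluate ...
// agent.Dispose() is called automatically at the end of the scope.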
ExportComputationGraph(List<ComputationNode<T>>)
Exports the policy network's computation graph for JIT compilation.
public override ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes (List<ComputationNode<T>>): List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the policy network's output.
Remarks
Exports the policy network (the network used for action selection) as a JIT-compilable computation graph. This enables fast, optimized inference for deployment.
What Gets Exported:
- DQN: Q-network outputting Q-values for all actions
- PPO/A3C: Policy network outputting action probabilities
- SAC/TD3: Actor network outputting continuous actions
What Is NOT Exported:
- Exploration strategies (epsilon-greedy, noise injection)
- Value/critic networks (not needed for inference)
- Target networks (only used during training)
Usage Example:
// After training the agent
if (agent.SupportsJitCompilation)
{
    var inputNodes = new List<ComputationNode<double>>();
    var output = agent.ExportComputationGraph(inputNodes);
    var jitCompiler = new JitCompiler();
    var compiled = jitCompiler.Compile(output, inputNodes);
    // Use for fast inference
    var actions = compiled.Evaluate(state);
}
Exceptions
- NotSupportedException
Thrown when the policy network does not support JIT compilation.
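If you cannot check SupportsJitCompilation beforehand, a defensive call might look like this (a sketch; the fallback path is an assumption):
var inputNodes = new List<ComputationNode<double>>();
try
{
    var output = agent.ExportComputationGraph(inputNodes);
    // ... compile and deploy the graph ...
}
catch (NotSupportedException)
{
    // Policy network is not JIT-compilable; fall back to regular inference.
}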
GetPolicyNetworkForJit()
Gets the policy network used for action selection.
protected virtual IJitCompilable<T>? GetPolicyNetworkForJit()
Returns
- IJitCompilable<T>?
The policy network, or null if no policy network is available.
Remarks
Override this method in derived classes to return the network responsible for action selection. This enables JIT compilation support for policy inference.
Examples:
- DQN: Returns the Q-network (actions selected via argmax Q(s,a))
- PPO/A3C: Returns the policy network (actor)
- SAC/TD3: Returns the policy network (actor)
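A minimal override sketch, assuming the agent stored its Q-network first in Networks (a local convention, not a library requirement):
// Hypothetical DQN-style override: expose the action-selection network
// for JIT export. Returning null disables JIT compilation.
protected override IJitCompilable<T>? GetPolicyNetworkForJit()
    => Networks.Count > 0 ? Networks[0] as IJitCompilable<T> : null;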