Class ReinforcementLearningAgentBase<T>
- Namespace
- AiDotNet.ReinforcementLearning.Agents
- Assembly
- AiDotNet.dll
Base class for all reinforcement learning agents, providing common functionality and structure.
public abstract class ReinforcementLearningAgentBase<T> : IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations (typically float or double).
- Inheritance
- object
- ReinforcementLearningAgentBase<T>
- Implements
- IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Remarks
This abstract base class defines the core structure that all RL agents must follow, ensuring consistency across different RL algorithms while allowing for specialized implementations. It integrates deeply with AiDotNet's existing architecture, using Vector, Matrix, and Tensor types, and following established patterns like OptimizerBase and NeuralNetworkBase.
For Beginners: This is the foundation for all RL agents in AiDotNet.
Think of this base class as the blueprint that defines what every RL agent must be able to do:
- Select actions based on observations
- Store experiences for learning
- Train/update from experiences
- Save and load trained models
- Integrate with AiDotNet's neural networks and optimizers
All specific RL algorithms (DQN, PPO, SAC, etc.) inherit from this base and implement their own unique learning logic while sharing common functionality.
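As a concrete sketch, the interaction loop below shows how these pieces fit together. It assumes a hypothetical environment object (env with Reset and Step methods) and a pre-built ReinforcementLearningOptions<double> instance, and uses DQNAgent<double> purely for illustration; any derived agent exposes the same base-class API.
// Sketch of a standard RL interaction loop built on the base-class API.
// "env", its Reset/Step methods, and "options" are illustrative assumptions.
var agent = new DQNAgent<double>(options);
for (int episode = 0; episode < 100; episode++)
{
    Vector<double> state = env.Reset();   // hypothetical environment call
    bool done = false;
    while (!done)
    {
        // Select an action with exploration enabled (training: true)
        Vector<double> action = agent.SelectAction(state, training: true);
        // Step the environment (hypothetical API: next state, reward, terminal flag)
        var (nextState, reward, isDone) = env.Step(action);
        // Store the experience and perform one learning update
        agent.StoreExperience(state, action, reward, nextState, isDone);
        double loss = agent.Train();
        state = nextState;
        done = isDone;
    }
    agent.ResetEpisode();
}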
Constructors
ReinforcementLearningAgentBase(ReinforcementLearningOptions<T>)
Initializes a new instance of the ReinforcementLearningAgentBase class.
protected ReinforcementLearningAgentBase(ReinforcementLearningOptions<T> options)
Parameters
options ReinforcementLearningOptions<T>: Configuration options for the agent.
Fields
DiscountFactor
Discount factor (gamma) for future rewards.
protected T DiscountFactor
Field Value
- T
Episodes
Number of episodes completed.
protected int Episodes
Field Value
- int
LearningRate
Learning rate for gradient updates.
protected T LearningRate
Field Value
- T
LossFunction
Loss function used for training.
protected readonly ILossFunction<T> LossFunction
Field Value
- ILossFunction<T>
LossHistory
History of losses during training.
protected readonly List<T> LossHistory
Field Value
- List<T>
NumOps
Numeric operations provider for type T.
protected readonly INumericOperations<T> NumOps
Field Value
- INumericOperations<T>
Options
Configuration options for this agent.
protected readonly ReinforcementLearningOptions<T> Options
Field Value
- ReinforcementLearningOptions<T>
Random
Random number generator for stochastic operations.
protected readonly Random Random
Field Value
- Random
RewardHistory
History of episode rewards.
protected readonly List<T> RewardHistory
Field Value
- List<T>
TrainingSteps
Number of training steps completed.
protected int TrainingSteps
Field Value
- int
Properties
DefaultLossFunction
Gets the default loss function for this agent.
public virtual ILossFunction<T> DefaultLossFunction { get; }
Property Value
- ILossFunction<T>
FeatureCount
Gets the number of input features (state dimensions).
public abstract int FeatureCount { get; }
Property Value
- int
FeatureNames
Gets the names of input features.
public virtual string[] FeatureNames { get; }
Property Value
- string[]
ParameterCount
Gets the number of parameters in the agent.
public abstract int ParameterCount { get; }
Property Value
- int
Remarks
Deep RL agents return parameter counts from neural networks. Classical RL agents (tabular, linear) may have different implementations.
SupportsJitCompilation
Gets whether this RL agent supports JIT compilation.
public virtual bool SupportsJitCompilation { get; }
Property Value
- bool
False for the base class. Derived classes may override to return true if they support JIT compilation.
Remarks
Most RL agents do not directly support JIT compilation because:
- They use layer-based neural networks without direct computation graph export
- Tabular methods use lookup tables rather than mathematical operations
- Policy selection often involves dynamic branching based on exploration strategies
Deep RL agents that use neural networks (DQN, PPO, SAC, etc.) may override this to delegate JIT compilation to their underlying policy or value networks if those networks support computation graph export.
For Beginners: JIT compilation speeds up models by converting them to optimized code.
RL agents typically don't support JIT compilation directly because:
- They combine multiple networks (policy, value, target networks)
- They use exploration strategies with random decisions
- The action selection process is complex and dynamic
However, the underlying neural networks used by deep RL agents (like the Q-network in DQN) can potentially be JIT compiled separately for faster inference.
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public abstract void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients Vector<T>
learningRate T
Clone()
Clones the agent.
public abstract IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
ComputeAverage(IEnumerable<T>)
Computes the average of a collection of values.
protected T ComputeAverage(IEnumerable<T> values)
Parameters
values IEnumerable<T>
Returns
- T
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public abstract Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>
Returns
- Vector<T>
DeepCopy()
Creates a deep copy of the agent.
public virtual IFullModel<T, Vector<T>, Vector<T>> DeepCopy()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
Deserialize(byte[])
Deserializes the agent from bytes.
public abstract void Deserialize(byte[] data)
Parameters
data byte[]
Dispose()
Disposes of resources used by the agent.
public virtual void Dispose()
ExportComputationGraph(List<ComputationNode<T>>)
Exports the agent's computation graph for JIT compilation.
public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)
Parameters
inputNodes List<ComputationNode<T>>: List to populate with input computation nodes.
Returns
- ComputationNode<T>
The output computation node representing the agent's prediction.
Remarks
The base RL agent class does not support JIT compilation because RL agents are complex systems that combine multiple components:
- Policy networks (select actions)
- Value networks (estimate state/action values)
- Target networks (provide stable training targets)
- Exploration strategies (epsilon-greedy, noise injection, etc.)
- Experience replay buffers
The action selection process in RL involves:
1. Forward pass through the policy/value network
2. Exploration decision (random vs. greedy)
3. Action sampling or selection
4. Potential action noise injection
This complex pipeline with dynamic branching is not suitable for JIT compilation.
Workaround for Deep RL Agents: If you need to accelerate inference for deep RL agents (DQN, PPO, SAC, etc.), consider JIT compiling the underlying neural networks separately:
// For DQN agent with Q-network
var dqnAgent = new DQNAgent<double>(options);
// Access the Q-network directly if exposed
// (This requires the agent to expose its networks publicly or via a property)
var qNetwork = dqnAgent.QNetwork; // hypothetical property
// JIT compile the Q-network for faster inference
if (qNetwork.SupportsJitCompilation)
{
    var inputNodes = new List<ComputationNode<double>>();
    var graphOutput = qNetwork.ExportComputationGraph(inputNodes);
    var jitCompiler = new JitCompiler<double>(graphOutput, inputNodes);
    // Use jitCompiler.Evaluate() for fast Q-value computation
}
For Tabular RL Agents: Tabular methods (Q-Learning, SARSA, etc.) use lookup tables rather than neural networks. They perform dictionary lookups which cannot be JIT compiled. These agents are already very fast for small state spaces and do not benefit from JIT compilation.
Exceptions
- NotSupportedException
RL agents do not support direct JIT compilation. Use the underlying neural network for JIT compilation if needed.
GetActiveFeatureIndices()
Gets the indices of active features.
public virtual IEnumerable<int> GetActiveFeatureIndices()
Returns
- IEnumerable<int>
GetFeatureImportance()
Gets feature importance scores.
public virtual Dictionary<string, T> GetFeatureImportance()
Returns
- Dictionary<string, T>
GetMetrics()
Gets the current training metrics.
public virtual Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
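A minimal sketch of reading these metrics after some training; the specific keys are implementation-defined and not guaranteed by the base class.
// Print whatever metrics the agent currently reports (keys vary by algorithm).
Dictionary<string, double> metrics = agent.GetMetrics();
foreach (var entry in metrics)
{
    Console.WriteLine($"{entry.Key}: {entry.Value}");
}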
GetModelMetadata()
Gets model metadata.
public abstract ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public abstract Vector<T> GetParameters()
Returns
- Vector<T>
IsFeatureUsed(int)
Checks if a feature is used by the agent.
public virtual bool IsFeatureUsed(int featureIndex)
Parameters
featureIndex int
Returns
- bool
LoadModel(string)
Loads the agent's state from a file.
public abstract void LoadModel(string filepath)
Parameters
filepath string: Path to load the agent from.
LoadState(Stream)
Loads the agent's state (parameters and configuration) from a stream.
public virtual void LoadState(Stream stream)
Parameters
stream Stream: The stream to read the agent state from.
Predict(Vector<T>)
Makes a prediction using the trained agent.
public virtual Vector<T> Predict(Vector<T> input)
Parameters
input Vector<T>
Returns
- Vector<T>
ResetEpisode()
Resets episode-specific state (if any).
public virtual void ResetEpisode()
SaveModel(string)
Saves the agent's state to a file.
public abstract void SaveModel(string filepath)
Parameters
filepath string: Path to save the agent.
SaveState(Stream)
Saves the agent's current state (parameters and configuration) to a stream.
public virtual void SaveState(Stream stream)
Parameters
stream Stream: The stream to write the agent state to.
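A minimal checkpointing sketch using the stream-based SaveState/LoadState pair; the file name is illustrative, and the agent being restored should be constructed with compatible options.
// Persist the agent's state to disk, then restore it later.
using (var output = File.Create("agent-checkpoint.bin"))
{
    agent.SaveState(output);
}
using (var input = File.OpenRead("agent-checkpoint.bin"))
{
    agent.LoadState(input);
}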
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public abstract Vector<T> SelectAction(Vector<T> state, bool training = true)
Parameters
state Vector<T>: The current state observation as a Vector.
training bool: Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
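During evaluation, exploration is typically turned off by passing training: false, as in this short sketch (state is assumed to be a prepared observation vector).
// Evaluation-mode action selection; exploration behavior is typically disabled.
Vector<double> action = agent.SelectAction(state, training: false);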
Serialize()
Serializes the agent to bytes.
public abstract byte[] Serialize()
Returns
- byte[]
SetActiveFeatureIndices(IEnumerable<int>)
Sets the active feature indices.
public virtual void SetActiveFeatureIndices(IEnumerable<int> indices)
Parameters
indices IEnumerable<int>
SetParameters(Vector<T>)
Sets the agent's parameters.
public abstract void SetParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public abstract void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
Parameters
state Vector<T>: The state before the action was taken.
action Vector<T>: The action taken.
reward T: The reward received.
nextState Vector<T>: The state after the action.
done bool: Whether the episode terminated.
Train()
Performs one training step, updating the agent's policy/value function.
public abstract T Train()
Returns
- T
The training loss for monitoring.
Train(Vector<T>, Vector<T>)
Trains the agent from a supervised input/output pair. This overload is not supported for RL agents, which learn from stored experiences via the parameterless Train() method.
public virtual void Train(Vector<T> input, Vector<T> output)
Parameters
input Vector<T>
output Vector<T>
WithParameters(Vector<T>)
Creates a new instance with the specified parameters.
public virtual IFullModel<T, Vector<T>, Vector<T>> WithParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
Returns
- IFullModel<T, Vector<T>, Vector<T>>
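A short sketch of how GetParameters, WithParameters, and SetParameters relate; trainedAgent and otherAgent are illustrative names, and whether the returned model copies state beyond the parameter vector depends on the derived class.
// Extract the trained parameter vector...
Vector<double> parameters = trainedAgent.GetParameters();
// ...create a new instance that uses those parameters, leaving the original untouched...
IFullModel<double, Vector<double>, Vector<double>> copy = trainedAgent.WithParameters(parameters);
// ...or overwrite a compatible agent's parameters in place.
otherAgent.SetParameters(parameters);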