
Class ReinforcementLearningAgentBase<T>

Namespace
AiDotNet.ReinforcementLearning.Agents
Assembly
AiDotNet.dll

Base class for all reinforcement learning agents, providing common functionality and structure.

public abstract class ReinforcementLearningAgentBase<T> : IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations (typically float or double).

Inheritance
object
ReinforcementLearningAgentBase<T>
Implements
IRLAgent<T>
IFullModel<T, Vector<T>, Vector<T>>
IModel<Vector<T>, Vector<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Vector<T>, Vector<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Vector<T>, Vector<T>>>
IGradientComputable<T, Vector<T>, Vector<T>>
IJitCompilable<T>
IDisposable

Remarks

This abstract base class defines the core structure that every RL agent must follow, keeping different RL algorithms consistent while still allowing specialized implementations. It integrates with AiDotNet's existing architecture: it uses the Vector, Matrix, and Tensor types and follows established patterns such as OptimizerBase and NeuralNetworkBase.

For Beginners: This is the foundation for all RL agents in AiDotNet.

Think of this base class as the blueprint that defines what every RL agent must be able to do:

  • Select actions based on observations
  • Store experiences for learning
  • Train/update from experiences
  • Save and load trained models
  • Integrate with AiDotNet's neural networks and optimizers

All specific RL algorithms (DQN, PPO, SAC, etc.) inherit from this base and implement their own unique learning logic while sharing common functionality.

Constructors

ReinforcementLearningAgentBase(ReinforcementLearningOptions<T>)

Initializes a new instance of the ReinforcementLearningAgentBase class.

protected ReinforcementLearningAgentBase(ReinforcementLearningOptions<T> options)

Parameters

options ReinforcementLearningOptions<T>

Configuration options for the agent.

Fields

DiscountFactor

Discount factor (gamma) for future rewards.

protected T DiscountFactor

Field Value

T
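As a sketch of what the discount factor controls: the discounted return of an episode satisfies G_t = r_t + γ·G_{t+1}, so a reward received k steps in the future is weighted by γ^k. The helper below is hypothetical (not part of AiDotNet) and uses double in place of T for illustration.

```csharp
using System;

// Discounted return G_0 = r_0 + γ·r_1 + γ²·r_2 + ..., computed backwards
// so each step is a single multiply-add: G_t = r_t + γ·G_{t+1}.
// Hypothetical helper; double stands in for the generic T.
double DiscountedReturn(double[] rewards, double gamma)
{
    double g = 0.0;
    for (int t = rewards.Length - 1; t >= 0; t--)
    {
        g = rewards[t] + gamma * g;
    }
    return g;
}

double ret = DiscountedReturn(new[] { 1.0, 1.0, 1.0 }, 0.9);
Console.WriteLine(ret); // approximately 2.71 (1 + 0.9 + 0.81)
```

With γ close to 1 the agent values distant rewards almost as much as immediate ones; with γ = 0 only the immediate reward counts.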

Episodes

Number of episodes completed.

protected int Episodes

Field Value

int

LearningRate

Learning rate for gradient updates.

protected T LearningRate

Field Value

T

LossFunction

Loss function used for training.

protected readonly ILossFunction<T> LossFunction

Field Value

ILossFunction<T>

LossHistory

History of losses during training.

protected readonly List<T> LossHistory

Field Value

List<T>

NumOps

Numeric operations provider for type T.

protected readonly INumericOperations<T> NumOps

Field Value

INumericOperations<T>

Options

Configuration options for this agent.

protected readonly ReinforcementLearningOptions<T> Options

Field Value

ReinforcementLearningOptions<T>

Random

Random number generator for stochastic operations.

protected readonly Random Random

Field Value

Random

RewardHistory

History of episode rewards.

protected readonly List<T> RewardHistory

Field Value

List<T>

TrainingSteps

Number of training steps completed.

protected int TrainingSteps

Field Value

int

Properties

DefaultLossFunction

Gets the default loss function for this agent.

public virtual ILossFunction<T> DefaultLossFunction { get; }

Property Value

ILossFunction<T>

FeatureCount

Gets the number of input features (state dimensions).

public abstract int FeatureCount { get; }

Property Value

int

FeatureNames

Gets the names of input features.

public virtual string[] FeatureNames { get; }

Property Value

string[]

ParameterCount

Gets the number of parameters in the agent.

public abstract int ParameterCount { get; }

Property Value

int

Remarks

Deep RL agents typically report the total number of parameters in their neural networks. Classical RL agents (tabular, linear) may compute this count differently, depending on how they store their values or weights.

SupportsJitCompilation

Gets whether this RL agent supports JIT compilation.

public virtual bool SupportsJitCompilation { get; }

Property Value

bool

False for the base class. Derived classes may override to return true if they support JIT compilation.

Remarks

Most RL agents do not directly support JIT compilation because:

  • They use layer-based neural networks without direct computation graph export
  • Tabular methods use lookup tables rather than mathematical operations
  • Policy selection often involves dynamic branching based on exploration strategies

Deep RL agents that use neural networks (DQN, PPO, SAC, etc.) may override this to delegate JIT compilation to their underlying policy or value networks if those networks support computation graph export.

For Beginners: JIT compilation speeds up models by converting them to optimized code.

RL agents typically don't support JIT compilation directly because:

  • They combine multiple networks (policy, value, target networks)
  • They use exploration strategies with random decisions
  • The action selection process is complex and dynamic

However, the underlying neural networks used by deep RL agents (like the Q-network in DQN) can potentially be JIT compiled separately for faster inference.

Methods

ApplyGradients(Vector<T>, T)

Applies gradients to update the agent.

public abstract void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>
learningRate T
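For illustration, the simplest update an ApplyGradients override could perform is plain gradient descent, θ ← θ − α·g, with α playing the role of learningRate. This is a minimal sketch on double arrays with a hypothetical GradientStep helper; a real agent may route the gradients through an optimizer instead, and whether gradients are subtracted or added depends on how ComputeGradients defines them.

```csharp
using System;

// Plain gradient-descent step: θ_i ← θ_i − α·g_i.
// Hypothetical helper operating on double arrays instead of Vector<T>.
double[] GradientStep(double[] parameters, double[] gradients, double learningRate)
{
    if (parameters.Length != gradients.Length)
    {
        throw new ArgumentException("Parameter and gradient lengths must match.");
    }

    var updated = new double[parameters.Length];
    for (int i = 0; i < parameters.Length; i++)
    {
        updated[i] = parameters[i] - learningRate * gradients[i];
    }
    return updated;
}

double[] theta = { 1.0, -2.0 };
double[] grad = { 0.5, -1.0 };
double[] next = GradientStep(theta, grad, 0.1); // roughly { 0.95, -1.9 }
Console.WriteLine(string.Join(", ", next));
```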

Clone()

Clones the agent.

public abstract IFullModel<T, Vector<T>, Vector<T>> Clone()

Returns

IFullModel<T, Vector<T>, Vector<T>>

ComputeAverage(IEnumerable<T>)

Computes the average of a collection of values.

protected T ComputeAverage(IEnumerable<T> values)

Parameters

values IEnumerable<T>

Returns

T
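A minimal sketch of an average over a loss or reward history, written with .NET 7+ generic math (INumber<T>) rather than AiDotNet's INumericOperations<T>. The name Mean and the choice to return zero for an empty sequence are assumptions of this sketch, not the documented behavior of ComputeAverage.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;

// Generic average over any numeric type (double, float, decimal, ...).
// Returning zero for an empty sequence is an assumption of this sketch.
// Note: for integer types the final division truncates.
T Mean<T>(IEnumerable<T> values) where T : INumber<T>
{
    T sum = T.Zero;
    int count = 0;
    foreach (T v in values)
    {
        sum += v;
        count++;
    }
    return count == 0 ? T.Zero : sum / T.CreateChecked(count);
}

var rewards = new List<double> { 10.0, 20.0, 30.0 };
double avg = Mean(rewards);                // 20
double recent = Mean(rewards.TakeLast(2)); // 25: average of the last two episodes
Console.WriteLine($"{avg} {recent}");
```

Averaging only the most recent entries, as with TakeLast above, smooths a noisy per-episode reward curve while still tracking recent progress.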

ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)

Computes gradients for the agent.

public abstract Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>

Returns

Vector<T>

DeepCopy()

Creates a deep copy of the agent.

public virtual IFullModel<T, Vector<T>, Vector<T>> DeepCopy()

Returns

IFullModel<T, Vector<T>, Vector<T>>

Deserialize(byte[])

Deserializes the agent from bytes.

public abstract void Deserialize(byte[] data)

Parameters

data byte[]

Dispose()

Disposes of resources used by the agent.

public virtual void Dispose()

ExportComputationGraph(List<ComputationNode<T>>)

Exports the agent's computation graph for JIT compilation.

public virtual ComputationNode<T> ExportComputationGraph(List<ComputationNode<T>> inputNodes)

Parameters

inputNodes List<ComputationNode<T>>

List to populate with input computation nodes.

Returns

ComputationNode<T>

The output computation node representing the agent's prediction.

Remarks

The base RL agent class does not support JIT compilation because RL agents are complex systems that combine multiple components:

  • Policy networks (select actions)
  • Value networks (estimate state/action values)
  • Target networks (provide stable training targets)
  • Exploration strategies (epsilon-greedy, noise injection, etc.)
  • Experience replay buffers

The action selection process in RL involves:

  1. A forward pass through the policy/value network
  2. An exploration decision (random vs. greedy)
  3. Action sampling or selection
  4. Potential action noise injection

This complex pipeline with dynamic branching is not suitable for JIT compilation.

Workaround for Deep RL Agents: If you need to accelerate inference for deep RL agents (DQN, PPO, SAC, etc.), consider JIT compiling the underlying neural networks separately:

// For a DQN agent with a Q-network
// (options is assumed to be an already-configured options instance for the agent)
var dqnAgent = new DQNAgent<double>(options);

// Access the Q-network directly if exposed
// (This requires the agent to expose its networks publicly or via a property)
var qNetwork = dqnAgent.QNetwork; // hypothetical property

// JIT compile the Q-network for faster inference
if (qNetwork.SupportsJitCompilation)
{
    var inputNodes = new List<ComputationNode<double>>();
    var graphOutput = qNetwork.ExportComputationGraph(inputNodes);
    var jitCompiler = new JitCompiler<double>(graphOutput, inputNodes);
    // Use jitCompiler.Evaluate() for fast Q-value computation
}

For Tabular RL Agents: Tabular methods (Q-Learning, SARSA, etc.) use lookup tables rather than neural networks. They perform dictionary lookups which cannot be JIT compiled. These agents are already very fast for small state spaces and do not benefit from JIT compilation.
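To make the contrast concrete, here is a minimal tabular Q-learning sketch using the standard update rule (not AiDotNet's implementation): the Q-table is a Dictionary keyed by (state, action), and both value lookup and the update are plain dictionary operations. Integer states and actions, the two-action space, and the helper names are assumptions for illustration.

```csharp
using System;
using System.Collections.Generic;

// Q-table keyed by (state, action); missing entries read as 0.
var q = new Dictionary<(int State, int Action), double>();
const int ActionCount = 2; // assumed fixed, small discrete action space

double GetQ(int s, int a) => q.TryGetValue((s, a), out double v) ? v : 0.0;

double MaxQ(int s)
{
    double best = double.NegativeInfinity;
    for (int a = 0; a < ActionCount; a++)
    {
        best = Math.Max(best, GetQ(s, a));
    }
    return best;
}

// Q(s,a) ← Q(s,a) + α·[r + γ·max_a' Q(s',a') − Q(s,a)]; terminal transitions use target r.
void Update(int s, int a, double r, int sNext, bool done, double alpha, double gamma)
{
    double target = done ? r : r + gamma * MaxQ(sNext);
    q[(s, a)] = GetQ(s, a) + alpha * (target - GetQ(s, a));
}

Update(0, 1, 1.0, 1, false, 0.5, 0.9);
Console.WriteLine(GetQ(0, 1)); // 0.5
Update(0, 1, 1.0, 1, false, 0.5, 0.9);
Console.WriteLine(GetQ(0, 1)); // 0.75
```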

Exceptions

NotSupportedException

RL agents do not support direct JIT compilation. Use the underlying neural network for JIT compilation if needed.

GetActiveFeatureIndices()

Gets the indices of active features.

public virtual IEnumerable<int> GetActiveFeatureIndices()

Returns

IEnumerable<int>

GetFeatureImportance()

Gets feature importance scores.

public virtual Dictionary<string, T> GetFeatureImportance()

Returns

Dictionary<string, T>

GetMetrics()

Gets the current training metrics.

public virtual Dictionary<string, T> GetMetrics()

Returns

Dictionary<string, T>

Dictionary of metric names to values.

GetModelMetadata()

Gets model metadata.

public abstract ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetParameters()

Gets the agent's parameters.

public abstract Vector<T> GetParameters()

Returns

Vector<T>

IsFeatureUsed(int)

Checks if a feature is used by the agent.

public virtual bool IsFeatureUsed(int featureIndex)

Parameters

featureIndex int

Returns

bool

LoadModel(string)

Loads the agent's state from a file.

public abstract void LoadModel(string filepath)

Parameters

filepath string

Path to load the agent from.

LoadState(Stream)

Loads the agent's state (parameters and configuration) from a stream.

public virtual void LoadState(Stream stream)

Parameters

stream Stream

The stream to read the agent state from.

Predict(Vector<T>)

Makes a prediction using the trained agent.

public virtual Vector<T> Predict(Vector<T> input)

Parameters

input Vector<T>

Returns

Vector<T>

ResetEpisode()

Resets episode-specific state (if any).

public virtual void ResetEpisode()

SaveModel(string)

Saves the agent's state to a file.

public abstract void SaveModel(string filepath)

Parameters

filepath string

Path to save the agent.

SaveState(Stream)

Saves the agent's current state (parameters and configuration) to a stream.

public virtual void SaveState(Stream stream)

Parameters

stream Stream

The stream to write the agent state to.
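A minimal sketch of a SaveState/LoadState pair, assuming a hypothetical layout of an int count followed by the parameter values as doubles. The actual format used by AiDotNet agents is not documented here, and it also covers configuration, which this sketch omits.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical on-stream layout: [int count][double × count].
// leaveOpen: true so the caller keeps ownership of the stream (and can rewind it).
void WriteParameters(Stream stream, double[] parameters)
{
    using var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true);
    writer.Write(parameters.Length);
    foreach (double p in parameters)
    {
        writer.Write(p);
    }
}

double[] ReadParameters(Stream stream)
{
    using var reader = new BinaryReader(stream, Encoding.UTF8, leaveOpen: true);
    int count = reader.ReadInt32();
    var parameters = new double[count];
    for (int i = 0; i < count; i++)
    {
        parameters[i] = reader.ReadDouble();
    }
    return parameters;
}

using var buffer = new MemoryStream();
WriteParameters(buffer, new[] { 0.25, -1.5, 3.0 });
buffer.Position = 0; // rewind before reading back
double[] restored = ReadParameters(buffer);
Console.WriteLine(string.Join(", ", restored)); // 0.25, -1.5, 3
```

Leaving the stream open mirrors the usual contract for Stream-based save/load methods: the caller created the stream, so the caller disposes it.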

SelectAction(Vector<T>, bool)

Selects an action given the current state observation.

public abstract Vector<T> SelectAction(Vector<T> state, bool training = true)

Parameters

state Vector<T>

The current state observation as a Vector.

training bool

Whether the agent is in training mode (affects exploration).

Returns

Vector<T>

Action as a Vector (can be discrete or continuous).
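For a discrete action space, the exploration controlled by training is often epsilon-greedy. The sketch below uses a hypothetical helper over double Q-values and returns an action index; it explores only when training is true. How a derived agent encodes the chosen action in the returned Vector<T> (for example, as an index or a one-hot vector) is up to that agent.

```csharp
using System;

// Epsilon-greedy action selection over Q-value estimates (discrete actions).
// When training is false the agent never explores and always picks the argmax.
int SelectActionEpsilonGreedy(double[] qValues, double epsilon, bool training, Random random)
{
    if (training && random.NextDouble() < epsilon)
    {
        return random.Next(qValues.Length); // explore: uniformly random action
    }

    int best = 0;
    for (int a = 1; a < qValues.Length; a++)
    {
        if (qValues[a] > qValues[best]) best = a;
    }
    return best; // exploit: highest estimated value (first one on ties)
}

var rng = new Random(42);
double[] q = { 0.1, 0.7, 0.3 };
int greedy = SelectActionEpsilonGreedy(q, 1.0, training: false, rng);  // evaluation: always 1
int explored = SelectActionEpsilonGreedy(q, 1.0, training: true, rng); // ε = 1: random index in [0, 3)
```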

Serialize()

Serializes the agent to bytes.

public abstract byte[] Serialize()

Returns

byte[]

SetActiveFeatureIndices(IEnumerable<int>)

Sets the active feature indices.

public virtual void SetActiveFeatureIndices(IEnumerable<int> indices)

Parameters

indices IEnumerable<int>

SetParameters(Vector<T>)

Sets the agent's parameters.

public abstract void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)

Stores an experience tuple for later learning.

public abstract void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)

Parameters

state Vector<T>

The state before action.

action Vector<T>

The action taken.

reward T

The reward received.

nextState Vector<T>

The state after action.

done bool

Whether the episode terminated.
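Off-policy agents such as DQN commonly keep these (state, action, reward, nextState, done) tuples in a fixed-size replay buffer and train on random mini-batches. The sketch below is a hypothetical circular buffer built from value tuples; the capacity, names, and uniform sampling with replacement are illustrative choices, not AiDotNet's implementation.

```csharp
using System;
using System.Collections.Generic;

// Fixed-capacity circular replay buffer of (s, a, r, s', done) tuples.
const int Capacity = 3;
var buffer = new (double[] State, double[] Action, double Reward, double[] NextState, bool Done)[Capacity];
int count = 0; // number of filled slots
int next = 0;  // slot the next experience will overwrite

void Store(double[] s, double[] a, double r, double[] sNext, bool done)
{
    buffer[next] = (s, a, r, sNext, done);
    next = (next + 1) % Capacity;          // wrap around: once full, the oldest entry is replaced
    count = Math.Min(count + 1, Capacity);
}

// Uniform random mini-batch, sampled with replacement from the filled slots.
List<(double[] State, double[] Action, double Reward, double[] NextState, bool Done)> Sample(int batchSize, Random random)
{
    var batch = new List<(double[] State, double[] Action, double Reward, double[] NextState, bool Done)>(batchSize);
    for (int i = 0; i < batchSize; i++)
    {
        batch.Add(buffer[random.Next(count)]);
    }
    return batch;
}

// Store five transitions; with capacity 3 only the last three (rewards 2, 3, 4) remain.
for (int i = 0; i < 5; i++)
{
    Store(new[] { (double)i }, new[] { 0.0 }, i, new[] { i + 1.0 }, i == 4);
}
var miniBatch = Sample(4, new Random(0));
```

Sampling at random from a buffer breaks the correlation between consecutive transitions, which is the usual motivation for experience replay.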

Train()

Performs one training step, updating the agent's policy/value function.

public abstract T Train()

Returns

T

The training loss for monitoring.
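The members above are typically driven by an interaction loop like the following. Everything here is a stand-in: local functions named after SelectAction, StoreExperience, Train, and ResetEpisode replace calls on a real derived agent, and Step is a hypothetical toy environment. Calling Train once per environment step is one common schedule; on-policy agents such as PPO may instead train after collecting a full rollout.

```csharp
using System;

// Stand-ins for a derived agent's members; in real code these would be
// agent.SelectAction(...), agent.StoreExperience(...), agent.Train(), agent.ResetEpisode().
var rng = new Random(0);
int stored = 0, trainCalls = 0;
double lastLoss = 0.0;

double[] SelectAction(double[] state, bool training) => new[] { (double)rng.Next(2) }; // random 0/1 policy
void StoreExperience(double[] s, double[] a, double r, double[] sNext, bool done) => stored++;
double Train() { trainCalls++; return 0.0; } // a real agent returns its training loss
void ResetEpisode() { }

// Hypothetical toy environment: action 1 moves one cell right; reaching cell 3
// ends the episode with reward 1, and every episode is capped at 10 steps.
(double[] NextState, double Reward, bool Done) Step(double[] state, double[] action, int t)
{
    double position = state[0] + (action[0] > 0.5 ? 1.0 : 0.0);
    bool reachedGoal = position >= 3.0;
    return (new[] { position }, reachedGoal ? 1.0 : 0.0, reachedGoal || t >= 9);
}

const int EpisodeCount = 5;
for (int episode = 0; episode < EpisodeCount; episode++)
{
    ResetEpisode();
    double[] state = { 0.0 };
    bool done = false;
    for (int t = 0; !done; t++)
    {
        double[] action = SelectAction(state, training: true);
        var (nextState, reward, isDone) = Step(state, action, t);
        StoreExperience(state, action, reward, nextState, isDone);
        lastLoss = Train(); // one training step per environment step
        state = nextState;
        done = isDone;
    }
}
Console.WriteLine($"{stored} transitions stored, {trainCalls} training steps");
```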

Train(Vector<T>, Vector<T>)

Trains the agent with supervised learning (not supported for RL agents).

public virtual void Train(Vector<T> input, Vector<T> output)

Parameters

input Vector<T>
output Vector<T>

WithParameters(Vector<T>)

Creates a new instance with the specified parameters.

public virtual IFullModel<T, Vector<T>, Vector<T>> WithParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

Returns

IFullModel<T, Vector<T>, Vector<T>>