Class REINFORCEAgent<T>

Namespace
AiDotNet.ReinforcementLearning.Agents.REINFORCE
Assembly
AiDotNet.dll

REINFORCE (Monte Carlo Policy Gradient) agent for reinforcement learning.

public class REINFORCEAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
REINFORCEAgent<T>
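Implements
IRLAgent<T>
IFullModel<T, Vector<T>, Vector<T>>
IModel<Vector<T>, Vector<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Vector<T>, Vector<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Vector<T>, Vector<T>>>
IGradientComputable<T, Vector<T>, Vector<T>>
IJitCompilable<T>
IDisposable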

Remarks

REINFORCE is the simplest and most fundamental policy gradient algorithm. It directly optimizes the policy by following the gradient of expected returns. Despite its simplicity, it forms the foundation for many modern RL algorithms.
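For readers who want the gradient written out, the standard textbook form of the REINFORCE estimator is given below. The notation is an assumption for illustration and is not taken from this library's source: \pi_\theta is the policy with parameters \theta, G_t is the return from step t, \gamma is a discount factor, and T is the episode length.

```latex
\nabla_\theta J(\theta)
  = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[
      \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, G_t
    \right],
\qquad
G_t = \sum_{k=t}^{T-1} \gamma^{\,k-t}\, r_k
```

Each action's log-probability is pushed up or down in proportion to the return that followed it, which is why a complete episode is needed before an update.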

For Beginners: REINFORCE is the "hello world" of policy gradient methods. The algorithm is beautifully simple:

  1. Play an entire episode
  2. For each action, calculate the total reward that followed it (its return)
  3. Make good actions more likely, bad actions less likely
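Steps 2 and 3 above can be sketched in a few lines of plain C#. This is a minimal illustration under stated assumptions, not the REINFORCEAgent<T> implementation: it uses double[] instead of Vector<T>, a state-independent softmax policy over raw logits, and a discount factor gamma whose configuration this page does not document. The names ReinforceSketch, DiscountedReturns, Softmax, and UpdateLogits are hypothetical.

```csharp
using System;

// Hypothetical helper for illustration only; not part of AiDotNet.
public static class ReinforceSketch
{
    // Step 2: compute the return G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...
    // for every step of one finished episode, working backwards from the last step.
    public static double[] DiscountedReturns(double[] rewards, double gamma)
    {
        var returns = new double[rewards.Length];
        double running = 0.0;
        for (int t = rewards.Length - 1; t >= 0; t--)
        {
            running = rewards[t] + gamma * running;
            returns[t] = running;
        }
        return returns;
    }

    // Converts logits to action probabilities (the max is subtracted for numerical stability).
    public static double[] Softmax(double[] logits)
    {
        double max = double.NegativeInfinity;
        foreach (double z in logits) max = Math.Max(max, z);

        var probs = new double[logits.Length];
        double sum = 0.0;
        for (int i = 0; i < logits.Length; i++)
        {
            probs[i] = Math.Exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.Length; i++) probs[i] /= sum;
        return probs;
    }

    // Step 3: one gradient-ascent step on log pi(action), scaled by the return.
    // For a softmax policy, d/d logit_i of log pi(action) is (i == action ? 1 : 0) - pi_i.
    public static void UpdateLogits(double[] logits, int action, double ret, double learningRate)
    {
        double[] probs = Softmax(logits);
        for (int i = 0; i < logits.Length; i++)
        {
            double gradLogProb = (i == action ? 1.0 : 0.0) - probs[i];
            logits[i] += learningRate * ret * gradLogProb;
        }
    }
}
```

With rewards {1, 1, 1} and gamma = 0.5, DiscountedReturns yields {1.75, 1.5, 1.0}, so earlier actions receive credit for everything that came after them. Calling UpdateLogits with a positive return raises the chosen action's logit and lowers the others; a negative return does the opposite.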

Think of it like learning to play a game:

  • You play a round
  • At the end, you see your score
  • You adjust your strategy to do better next time

Pros: Simple, easy to understand, works with any differentiable stochastic policy

Cons: High variance, slow learning, requires complete episodes

Modern algorithms like PPO and A2C improve on REINFORCE's core ideas.

Reference: Williams, R. J. (1992). "Simple statistical gradient-following algorithms for connectionist reinforcement learning." Machine Learning, 8(3-4), 229-256.

Constructors

REINFORCEAgent(REINFORCEOptions<T>)

public REINFORCEAgent(REINFORCEOptions<T> options)

Parameters

options REINFORCEOptions<T>

Properties

FeatureCount

Gets the number of input features (state dimensions).

public override int FeatureCount { get; }

Property Value

int

Methods

ApplyGradients(Vector<T>, T)

Applies gradients to update the agent.

public override void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>
learningRate T

Clone()

Clones the agent.

public override IFullModel<T, Vector<T>, Vector<T>> Clone()

Returns

IFullModel<T, Vector<T>, Vector<T>>

ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)

Computes gradients for the agent.

public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>

Returns

Vector<T>

Deserialize(byte[])

Deserializes the agent from bytes.

public override void Deserialize(byte[] data)

Parameters

data byte[]

GetMetrics()

Gets the current training metrics.

public override Dictionary<string, T> GetMetrics()

Returns

Dictionary<string, T>

Dictionary of metric names to values.

GetModelMetadata()

Gets model metadata.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetParameters()

Gets the agent's parameters.

public override Vector<T> GetParameters()

Returns

Vector<T>

LoadModel(string)

Loads the agent's state from a file.

public override void LoadModel(string filepath)

Parameters

filepath string

Path to load the agent from.

SaveModel(string)

Saves the agent's state to a file.

public override void SaveModel(string filepath)

Parameters

filepath string

Path to save the agent.

SelectAction(Vector<T>, bool)

Selects an action given the current state observation.

public override Vector<T> SelectAction(Vector<T> state, bool training = true)

Parameters

state Vector<T>

The current state observation as a Vector.

training bool

Whether the agent is in training mode (affects exploration).

Returns

Vector<T>

Action as a Vector (can be discrete or continuous).

Serialize()

Serializes the agent to bytes.

public override byte[] Serialize()

Returns

byte[]

SetParameters(Vector<T>)

Sets the agent's parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)

Stores an experience tuple for later learning.

public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)

Parameters

state Vector<T>

The state before action.

action Vector<T>

The action taken.

reward T

The reward received.

nextState Vector<T>

The state after action.

done bool

Whether the episode terminated.

Train()

Performs one training step, updating the agent's policy/value function.

public override T Train()

Returns

T

The training loss for monitoring.
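The order in which SelectAction, StoreExperience, and Train are usually called over one episode can be sketched as follows. This is an illustration under assumptions, not a usage example of the real class: IEpisodicAgent, RecordingAgent, and EpisodeLoopSketch are hypothetical stand-ins that mirror only the three method shapes documented on this page, with double[] in place of Vector<T> and a toy environment. The page does not state when the real Train() must be called; because REINFORCE requires complete episodes, the sketch calls it once after the episode ends.

```csharp
using System;

// Hypothetical stand-in mirroring the three documented method shapes; NOT the AiDotNet API.
public interface IEpisodicAgent
{
    double[] SelectAction(double[] state, bool training = true);
    void StoreExperience(double[] state, double[] action, double reward, double[] nextState, bool done);
    double Train();
}

// Toy agent: picks action 0 or 1 at random and counts calls so the call order can be observed.
public sealed class RecordingAgent : IEpisodicAgent
{
    private readonly Random _rng = new Random(0);
    public int StoredSteps { get; private set; }
    public int TrainCalls { get; private set; }

    public double[] SelectAction(double[] state, bool training = true)
        => new[] { (double)_rng.Next(2) };

    public void StoreExperience(double[] state, double[] action, double reward, double[] nextState, bool done)
        => StoredSteps++;

    public double Train()
    {
        TrainCalls++;
        return 0.0; // a real agent would return its training loss here
    }
}

public static class EpisodeLoopSketch
{
    // Runs one episode of a toy environment that ends after maxSteps steps
    // and pays a reward of 1 whenever action 1 is chosen.
    public static double RunEpisode(IEpisodicAgent agent, int maxSteps)
    {
        double totalReward = 0.0;
        var state = new[] { 0.0 };
        for (int step = 0; step < maxSteps; step++)
        {
            double[] action = agent.SelectAction(state, training: true);
            double reward = action[0] == 1.0 ? 1.0 : 0.0;
            var nextState = new[] { step + 1.0 };
            bool done = step == maxSteps - 1;

            // Record the transition; nothing is learned until the episode is complete.
            agent.StoreExperience(state, action, reward, nextState, done);
            totalReward += reward;
            state = nextState;
        }

        agent.Train(); // one policy update after the full episode
        return totalReward;
    }
}
```

Calling EpisodeLoopSketch.RunEpisode(new RecordingAgent(), 5) stores five transitions and triggers a single training call, which matches the play-then-update pattern described in the Remarks.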