Class REINFORCEAgent<T>
- Namespace
- AiDotNet.ReinforcementLearning.Agents.REINFORCE
- Assembly
- AiDotNet.dll
REINFORCE (Monte Carlo Policy Gradient) agent for reinforcement learning.
public class REINFORCEAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
TThe numeric type used for calculations.
- Inheritance
-
REINFORCEAgent<T>
- Implements
-
IRLAgent<T>
- Inherited Members
- Extension Methods
Remarks
REINFORCE is the simplest and most fundamental policy gradient algorithm. It directly optimizes the policy by following the gradient of expected returns. Despite its simplicity, it forms the foundation for many modern RL algorithms.
For Beginners: REINFORCE is the "hello world" of policy gradient methods. The algorithm is beautifully simple:
- Play an entire episode
- Calculate total rewards for each action
- Make good actions more likely, bad actions less likely
Think of it like learning to play a game:
- You play a round
- At the end, you see your score
- You adjust your strategy to do better next time
Pros: Simple, works for any problem, easy to understand Cons: High variance, slow learning, requires complete episodes
Modern algorithms like PPO and A2C improve on REINFORCE's core ideas.
Reference: Williams, R. J. (1992). "Simple statistical gradient-following algorithms for connectionist RL."
Constructors
REINFORCEAgent(REINFORCEOptions<T>)
public REINFORCEAgent(REINFORCEOptions<T> options)
Parameters
optionsREINFORCEOptions<T>
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradientsVector<T>learningRateT
Clone()
Clones the agent.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
inputVector<T>targetVector<T>lossFunctionILossFunction<T>
Returns
- Vector<T>
Deserialize(byte[])
Deserializes the agent from bytes.
public override void Deserialize(byte[] data)
Parameters
databyte[]
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads the agent's state from a file.
public override void LoadModel(string filepath)
Parameters
filepathstringPath to load the agent from.
SaveModel(string)
Saves the agent's state to a file.
public override void SaveModel(string filepath)
Parameters
filepathstringPath to save the agent.
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> state, bool training = true)
Parameters
stateVector<T>The current state observation as a Vector.
trainingboolWhether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
Serialize()
Serializes the agent to bytes.
public override byte[] Serialize()
Returns
- byte[]
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parametersVector<T>
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
Parameters
stateVector<T>The state before action.
actionVector<T>The action taken.
rewardTThe reward received.
nextStateVector<T>The state after action.
doneboolWhether the episode terminated.
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.