Class A3CAgent<T>
- Namespace
- AiDotNet.ReinforcementLearning.Agents.A3C
- Assembly
- AiDotNet.dll
Asynchronous Advantage Actor-Critic (A3C) agent for reinforcement learning.
public class A3CAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
- T
The numeric type used for calculations.
- Inheritance
- DeepReinforcementLearningAgentBase<T>
- A3CAgent<T>
- Implements
- IRLAgent<T>
Remarks
A3C runs multiple agents in parallel, each exploring different strategies. Workers periodically synchronize with a global network, enabling diverse exploration without replay buffers.
For Beginners: A3C is like having multiple students learn simultaneously - each has different experiences, and they periodically share knowledge with a "master" network. This parallel learning provides stability and diverse exploration.
Key features:
- Asynchronous Updates: Multiple workers update the global network independently
- No Replay Buffer: On-policy learning with parallel exploration
- Actor-Critic: Learns both policy and value function
- Diverse Exploration: Each worker explores differently
Famous for: Introduced by DeepMind in 2016; trains efficiently on multi-core CPUs without requiring a GPU
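The worker and global-network pattern described in these remarks can be sketched roughly as follows. This is a minimal illustration, not AiDotNet's implementation: `GlobalNetwork`, `A3CSketch`, and the toy quadratic objective are hypothetical stand-ins, and plain `double[]` arrays replace `Vector<T>`. A real worker would derive policy and value gradients from its own environment rollouts.

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical stand-in for the shared "global network": just a parameter vector.
public sealed class GlobalNetwork
{
    private readonly double[] _parameters;
    private readonly object _gate = new object();

    public GlobalNetwork(int size) => _parameters = new double[size];

    // Workers pull a private copy so they can compute gradients without holding the lock.
    public double[] Snapshot()
    {
        lock (_gate) { return (double[])_parameters.Clone(); }
    }

    // Workers push gradients whenever they are ready; no worker waits for another.
    public void ApplyGradients(double[] gradients, double learningRate)
    {
        lock (_gate)
        {
            for (int i = 0; i < _parameters.Length; i++)
                _parameters[i] -= learningRate * gradients[i];
        }
    }
}

public static class A3CSketch
{
    // Toy objective (assumption): minimize sum((p - target)^2), gradient 2 * (p - target).
    private static double[] ToyGradient(double[] p, double target) =>
        p.Select(x => 2.0 * (x - target)).ToArray();

    public static void RunWorker(GlobalNetwork global, int updates, double target)
    {
        for (int step = 0; step < updates; step++)
        {
            double[] local = global.Snapshot();          // synchronize with the global network
            double[] grads = ToyGradient(local, target); // compute gradients on the local copy
            global.ApplyGradients(grads, 0.05);          // asynchronous update of shared parameters
        }
    }

    public static double[] TrainInParallel(int workers, int updatesPerWorker)
    {
        var global = new GlobalNetwork(4);
        Task[] tasks = Enumerable.Range(0, workers)
            .Select(_ => Task.Run(() => RunWorker(global, updatesPerWorker, 1.0)))
            .ToArray();
        Task.WaitAll(tasks);
        return global.Snapshot();
    }
}
```

Calling `A3CSketch.TrainInParallel(4, 200)` runs four workers against one shared parameter vector. Gradients may be computed from slightly stale snapshots, which is the trade-off asynchronous updates accept in exchange for not blocking workers on each other.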
Constructors
A3CAgent(A3COptions<T>, IOptimizer<T, Vector<T>, Vector<T>>?)
public A3CAgent(A3COptions<T> options, IOptimizer<T, Vector<T>, Vector<T>>? optimizer = null)
Parameters
- options A3COptions<T>
- optimizer IOptimizer<T, Vector<T>, Vector<T>>
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
- int
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
- gradients Vector<T>
- learningRate T
Clone()
Clones the agent.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
- input Vector<T>
- target Vector<T>
- lossFunction ILossFunction<T>
Returns
- Vector<T>
Deserialize(byte[])
Deserializes the agent from bytes.
public override void Deserialize(byte[] data)
Parameters
- data byte[]
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads the agent's state from a file.
public override void LoadModel(string filepath)
Parameters
- filepath string
Path to load the agent from.
Predict(Vector<T>)
Makes a prediction using the trained agent.
public override Vector<T> Predict(Vector<T> input)
Parameters
- input Vector<T>
Returns
- Vector<T>
PredictAsync(Vector<T>)
public Task<Vector<T>> PredictAsync(Vector<T> input)
Parameters
- input Vector<T>
Returns
- Task<Vector<T>>
ResetEpisode()
Resets episode-specific state (if any).
public override void ResetEpisode()
SaveModel(string)
Saves the agent's state to a file.
public override void SaveModel(string filepath)
Parameters
- filepath string
Path to save the agent.
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> state, bool training = true)
Parameters
- state Vector<T>
The current state observation as a Vector.
- training bool
Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
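How the `training` flag can affect exploration is sketched below for the discrete-action case. This is an assumption about typical actor-critic behavior rather than a description of this class's internals: the sketch samples from a softmax policy while training and picks the most probable action otherwise. `ActionSelectionSketch`, the logits input, and the one-hot output are all hypothetical, with `double[]` standing in for `Vector<T>`.

```csharp
using System;
using System.Linq;

public static class ActionSelectionSketch
{
    // Numerically stable softmax: subtract the largest logit before exponentiating.
    public static double[] Softmax(double[] logits)
    {
        double max = logits.Max();
        double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
        double sum = exps.Sum();
        return exps.Select(e => e / sum).ToArray();
    }

    // Returns a one-hot action vector (assumed encoding for discrete actions).
    public static double[] SelectAction(double[] logits, bool training, Random rng)
    {
        double[] probs = Softmax(logits);
        int chosen;
        if (training)
        {
            // Exploration: sample an action index from the policy distribution.
            double u = rng.NextDouble();
            double cumulative = 0.0;
            chosen = probs.Length - 1; // fallback guards against rounding error
            for (int i = 0; i < probs.Length; i++)
            {
                cumulative += probs[i];
                if (u < cumulative) { chosen = i; break; }
            }
        }
        else
        {
            // Evaluation: act greedily on the most probable action.
            chosen = Array.IndexOf(probs, probs.Max());
        }

        var action = new double[probs.Length];
        action[chosen] = 1.0;
        return action;
    }
}
```

With logits `{ 0.1, 2.0, -1.0 }` and `training: false`, the sketch always selects index 1; with `training: true` it usually selects index 1 but occasionally tries the others.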
Serialize()
Serializes the agent to bytes.
public override byte[] Serialize()
Returns
- byte[]
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
- parameters Vector<T>
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
Parameters
- state Vector<T>
The state before action.
- action Vector<T>
The action taken.
- reward T
The reward received.
- nextState Vector<T>
The state after action.
- done bool
Whether the episode terminated.
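Because A3C is on-policy, stored experience is typically consumed once to compute discounted returns and advantages, then discarded. The sketch below shows that calculation under stated assumptions: `TrajectorySketch` is hypothetical, it stores only the reward, the critic's value estimate, and the `done` flag, and it uses `double` in place of `T`. It is not a description of how this class stores experience internally.

```csharp
using System;
using System.Collections.Generic;

public sealed class TrajectorySketch
{
    private readonly List<(double Reward, double Value, bool Done)> _steps =
        new List<(double Reward, double Value, bool Done)>();

    // valueEstimate is the critic's V(s_t) at the time the step was taken.
    public void Store(double reward, double valueEstimate, bool done) =>
        _steps.Add((reward, valueEstimate, done));

    // Walks the rollout backwards: R_t = r_t + gamma * R_{t+1}, restarting at episode ends.
    // bootstrapValue is V(s) for the state after the last stored step (ignored if that step is terminal).
    public double[] ComputeAdvantages(double gamma, double bootstrapValue)
    {
        var advantages = new double[_steps.Count];
        double ret = bootstrapValue;
        for (int t = _steps.Count - 1; t >= 0; t--)
        {
            var (reward, value, done) = _steps[t];
            if (done) ret = 0.0;         // no future return past a terminal step
            ret = reward + gamma * ret;  // discounted return R_t
            advantages[t] = ret - value; // advantage A_t = R_t - V(s_t)
        }
        _steps.Clear(); // on-policy: experience is used once, then dropped
        return advantages;
    }
}
```

The advantage tells the actor how much better an action turned out than the critic expected, which is the "Advantage" in the algorithm's name.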
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.
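The order in which `ResetEpisode`, `SelectAction`, `StoreExperience`, and `Train` are usually combined is sketched below. To stay self-contained, the sketch does not reference AiDotNet types: `IAgentSketch` is a hypothetical mirror of the documented member signatures with `double[]` in place of `Vector<T>`, and `ToyEnvironment` and `StubAgent` are invented purely for illustration. The real `IEnvironment<T>` members are not shown in this page and are not assumed here.

```csharp
using System;

// Hypothetical mirror of the documented members, using double[] instead of Vector<T>.
public interface IAgentSketch
{
    double[] SelectAction(double[] state, bool training = true);
    void StoreExperience(double[] state, double[] action, double reward, double[] nextState, bool done);
    double Train();
    void ResetEpisode();
}

// Hypothetical environment: a 1-D walk that ends at position 3 or after 10 steps.
public sealed class ToyEnvironment
{
    private int _position, _steps;

    public double[] Reset() { _position = 0; _steps = 0; return new[] { 0.0 }; }

    public (double[] NextState, double Reward, bool Done) Step(double[] action)
    {
        _position += action[0] > 0.5 ? 1 : 0;
        _steps++;
        bool done = _position >= 3 || _steps >= 10;
        return (new[] { (double)_position }, _position >= 3 ? 1.0 : 0.0, done);
    }
}

// Stub agent: always moves right and reports the number of stored steps as a placeholder "loss".
public sealed class StubAgent : IAgentSketch
{
    private int _stored;
    public double[] SelectAction(double[] state, bool training = true) => new[] { 1.0 };
    public void StoreExperience(double[] state, double[] action, double reward, double[] nextState, bool done) => _stored++;
    public double Train() { double loss = _stored; _stored = 0; return loss; }
    public void ResetEpisode() { }
}

public static class TrainingLoopSketch
{
    public static double RunEpisode(IAgentSketch agent, ToyEnvironment env)
    {
        agent.ResetEpisode();
        double[] state = env.Reset();
        bool done = false;
        while (!done)
        {
            double[] action = agent.SelectAction(state, training: true);
            var (nextState, reward, stepDone) = env.Step(action);
            agent.StoreExperience(state, action, reward, nextState, stepDone);
            state = nextState;
            done = stepDone;
        }
        return agent.Train(); // one update from the experience gathered this episode
    }
}
```

The returned value is intended for monitoring only, matching the description of `Train()` as returning the training loss.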
TrainAsync()
public Task TrainAsync()
Returns
- Task
TrainAsync(IEnvironment<T>, int)
Trains A3C with parallel workers. The current implementation is simplified for single-threaded execution; a production version would spawn actual parallel tasks.
public Task TrainAsync(IEnvironment<T> environment, int maxSteps)
Parameters
- environment IEnvironment<T>
- maxSteps int