Class MADDPGAgent<T>
Namespace: AiDotNet.ReinforcementLearning.Agents.MADDPG
Assembly: AiDotNet.dll
Agent implementing Multi-Agent Deep Deterministic Policy Gradient (MADDPG) for cooperative, competitive, and mixed-motive multi-agent environments.
public class MADDPGAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T
The numeric type used for calculations (for example, float or double).
- Inheritance
- DeepReinforcementLearningAgentBase<T> → MADDPGAgent<T>
- Implements
- IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Constructors
MADDPGAgent(MADDPGOptions<T>, IOptimizer<T, Vector<T>, Vector<T>>?)
public MADDPGAgent(MADDPGOptions<T> options, IOptimizer<T, Vector<T>, Vector<T>>? optimizer = null)
Parameters
options MADDPGOptions<T>
optimizer IOptimizer<T, Vector<T>, Vector<T>> (optional; defaults to null)
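Example
A minimal construction sketch. MADDPGOptions<T> is assumed here to expose a parameterless constructor and settable configuration properties (state/action dimensions, number of agents, and so on); consult the MADDPGOptions<T> reference for the actual members. Omitting the optimizer argument uses the agent's built-in default.

using AiDotNet.ReinforcementLearning.Agents.MADDPG;

// Assumed: parameterless constructor; configure the options to match your environment.
var options = new MADDPGOptions<double>();

// Omitting the optimizer argument falls back to the agent's internal default.
var agent = new MADDPGAgent<double>(options);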
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
- int
Methods
ApplyGradients(Vector<T>, T)
Not supported for MADDPGAgent. Use the agent's internal Train() loop instead.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients Vector<T>: Not used.
learningRate T: Not used.
Exceptions
- NotSupportedException
Always thrown. MADDPG manages gradient computation and parameter updates internally through backpropagation.
Clone()
Creates a deep copy of this MADDPG agent including all trained network weights.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
A new MADDPG agent with the same configuration and trained parameters.
Remarks
Clone copies all trained weights from the actor and critic networks using GetParameters() and SetParameters(), so the cloned agent reproduces the same learned behavior (fix for issue #5).
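Example
A short sketch of cloning a trained agent; only the Clone() signature above is taken from this page, and trainedAgent stands in for an already-trained MADDPGAgent<double>.

// The clone carries the same configuration and trained weights,
// so it can be evaluated or trained further without affecting the original.
IFullModel<double, Vector<double>, Vector<double>> copy = trainedAgent.Clone();
var clonedAgent = (MADDPGAgent<double>)copy;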
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input Vector<T>
target Vector<T>
lossFunction ILossFunction<T> (optional; defaults to null)
Returns
- Vector<T>
Deserialize(byte[])
Deserializes a MADDPG agent from a byte array.
public override void Deserialize(byte[] data)
Parameters
data byte[]: Byte array containing the serialized agent data.
Remarks
Expects data created by Serialize() with a compatible configuration.
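Example
A hedged round-trip sketch. The remarks above state that Deserialize expects data produced by Serialize() with a compatible configuration, so the receiving agent is constructed with the same (or equivalent) options; compatibleOptions is a placeholder for such a configuration.

byte[] data = trainedAgent.Serialize();

// The target agent must be constructed with a compatible configuration.
var restored = new MADDPGAgent<double>(compatibleOptions);
restored.Deserialize(data);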
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
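Example
A small sketch of inspecting training metrics; the specific metric keys are not documented on this page, so the example simply enumerates whatever the agent reports.

Dictionary<string, double> metrics = agent.GetMetrics();
foreach (var metric in metrics)
{
    Console.WriteLine($"{metric.Key}: {metric.Value}");
}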
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads a trained model from a file.
public override void LoadModel(string filepath)
Parameters
filepath string: Path to load the model from.
Remarks
Uses Deserialize(byte[]) to restore network weights.
Predict(Vector<T>)
Makes a prediction using the trained agent.
public override Vector<T> Predict(Vector<T> input)
Parameters
input Vector<T>
Returns
- Vector<T>
PredictAsync(Vector<T>)
Asynchronously makes a prediction using the trained agent.
public Task<Vector<T>> PredictAsync(Vector<T> input)
Parameters
input Vector<T>
Returns
- Task<Vector<T>>
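Example
A brief sketch of asynchronous prediction inside an async method; observation is a hypothetical Vector<double> obtained from your environment.

// Awaiting the task yields the predicted action vector.
Vector<double> action = await agent.PredictAsync(observation);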
ResetEpisode()
Resets episode-specific state (if any).
public override void ResetEpisode()
SaveModel(string)
Saves the trained model to a file.
public override void SaveModel(string filepath)
Parameters
filepath string: Path to save the model.
Remarks
Uses Serialize() to persist network weights.
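Example
A hedged persistence sketch combining SaveModel and LoadModel. The file path is illustrative, and, as with Deserialize(byte[]), the loading agent is assumed to be constructed with a compatible configuration (compatibleOptions is a placeholder).

trainedAgent.SaveModel("maddpg_agent.bin");   // illustrative path

var reloaded = new MADDPGAgent<double>(compatibleOptions);
reloaded.LoadModel("maddpg_agent.bin");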
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> state, bool training = true)
Parameters
state Vector<T>: The current state observation as a Vector.
training bool: Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
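Example
A sketch of selecting actions during training versus evaluation; observation is a hypothetical Vector<double> state from your environment, and only the SelectAction signature above comes from this page.

// training defaults to true, so exploration is enabled.
Vector<double> exploratoryAction = agent.SelectAction(observation);

// Disable training mode for evaluation to suppress exploration.
Vector<double> evaluationAction = agent.SelectAction(observation, training: false);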
SelectActionForAgent(int, Vector<T>, bool)
Selects an action for the specified agent given its local state observation.
public Vector<T> SelectActionForAgent(int agentId, Vector<T> state, bool training = true)
Parameters
agentId int: Index of the agent to select an action for.
state Vector<T>: The agent's current state observation.
training bool: Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
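Example
A sketch of per-agent action selection; perAgentObservations is a hypothetical List<Vector<double>> holding each agent's local observation, indexed by agent id.

var actions = new List<Vector<double>>();
for (int agentId = 0; agentId < perAgentObservations.Count; agentId++)
{
    // Each agent acts on its own observation.
    actions.Add(agent.SelectActionForAgent(agentId, perAgentObservations[agentId]));
}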
Serialize()
Serializes the MADDPG agent to a byte array.
public override byte[] Serialize()
Returns
- byte[]
Byte array containing the serialized agent data.
Remarks
Serializes configuration values and all actor/critic network weights.
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
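Example
A sketch of transferring learned weights between two agents via the flat parameter vector; both agents are assumed to share the same configuration so the parameter layouts match.

Vector<double> parameters = sourceAgent.GetParameters();
targetAgent.SetParameters(parameters);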
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
Parameters
state Vector<T>: The state before the action.
action Vector<T>: The action taken.
reward T: The reward received.
nextState Vector<T>: The state after the action.
done bool: Whether the episode terminated.
StoreMultiAgentExperience(List<Vector<T>>, List<Vector<T>>, List<T>, List<Vector<T>>, bool)
Store multi-agent experience with per-agent reward tracking.
public void StoreMultiAgentExperience(List<Vector<T>> states, List<Vector<T>> actions, List<T> rewards, List<Vector<T>> nextStates, bool done)
Parameters
states List<Vector<T>>: Per-agent states before the joint action.
actions List<Vector<T>>: Per-agent actions taken.
rewards List<T>: Per-agent rewards received.
nextStates List<Vector<T>>: Per-agent states after the joint action.
done bool: Whether the episode terminated.
Remarks
Stores individual rewards for each agent to support both cooperative and competitive/mixed-motive scenarios. For backward compatibility, also stores an averaged reward in the experience.
The per-agent rewards are stored keyed by the buffer index where the experience will be placed. This accounts for the circular buffer behavior when capacity is reached.
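Example
A sketch of storing one joint transition with per-agent rewards; the per-agent lists are hypothetical and must be ordered consistently by agent index across all arguments.

agent.StoreMultiAgentExperience(
    perAgentStates,       // List<Vector<double>>: each agent's state before the joint action
    perAgentActions,      // List<Vector<double>>: each agent's action
    perAgentRewards,      // List<double>: each agent's individual reward
    perAgentNextStates,   // List<Vector<double>>: each agent's next state
    episodeEnded);        // bool: whether the episode terminated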
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.
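Example
A hedged training-loop sketch tying together ResetEpisode, SelectAction, StoreExperience, and Train. The environment object (env) and its Reset/Step methods are hypothetical placeholders; only the agent calls come from this page.

for (int episode = 0; episode < numEpisodes; episode++)
{
    agent.ResetEpisode();
    Vector<double> state = env.Reset();   // hypothetical environment API
    bool done = false;

    while (!done)
    {
        Vector<double> action = agent.SelectAction(state);
        var (nextState, reward, isDone) = env.Step(action);   // hypothetical environment API

        agent.StoreExperience(state, action, reward, nextState, isDone);
        double loss = agent.Train();       // one update step; returns the loss for monitoring

        state = nextState;
        done = isDone;
    }
}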
TrainAsync()
Asynchronously performs one training step.
public Task TrainAsync()
Returns
- Task