Class MuZeroAgent<T>
- Namespace: AiDotNet.ReinforcementLearning.Agents.MuZero
- Assembly: AiDotNet.dll
MuZero agent combining tree search with learned models.
public class MuZeroAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T
The numeric type used for calculations.
- Inheritance: DeepReinforcementLearningAgentBase<T> -> MuZeroAgent<T>
- Implements: IRLAgent<T>
Remarks
MuZero combines AlphaZero-style tree search with a learned model of the environment's dynamics. It learns to play games without being given their rules by building its own internal model.
For Beginners: MuZero is a DeepMind algorithm that matched AlphaZero's superhuman performance in Go, chess, and shogi and set a new state of the art on Atari games, all without being told the rules. It learns its own "internal model" of the game and uses tree search over that model to plan ahead.
Three key networks:
- Representation: Observation -> hidden state
- Dynamics: (hidden state, action) -> (next hidden state, reward)
- Prediction: hidden state -> (policy, value)
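To make the data flow between the three functions concrete, here is a minimal C# sketch. The class MuZeroFunctionsSketch and its tiny hand-written "networks" are hypothetical stand-ins, not AiDotNet's implementation: they use double[] instead of Vector<T>, the arithmetic is arbitrary, and, as a toy assumption, an action is simply an index into the hidden state.

```csharp
using System;
using System.Linq;

// Hypothetical toy stand-ins for MuZero's three learned functions.
// The arithmetic is arbitrary; only the inputs and outputs mirror MuZero.
public static class MuZeroFunctionsSketch
{
    // Representation: observation -> hidden state.
    public static double[] Represent(double[] observation) =>
        observation.Select(x => Math.Tanh(x)).ToArray();

    // Dynamics: (hidden state, action) -> (next hidden state, reward).
    // Toy assumption: the action is an index into the hidden state.
    public static (double[] NextHidden, double Reward) Dynamics(double[] hidden, int action)
    {
        double[] next = hidden
            .Select((h, i) => Math.Tanh(h + (i == action ? 1.0 : 0.0)))
            .ToArray();
        return (next, next.Sum());
    }

    // Prediction: hidden state -> (policy over actions, value).
    public static (double[] Policy, double Value) Predict(double[] hidden)
    {
        double[] exp = hidden.Select(h => Math.Exp(h)).ToArray();
        double total = exp.Sum();
        return (exp.Select(e => e / total).ToArray(), hidden.Average());
    }

    // Plan "in imagination": encode the observation once, then step the
    // learned dynamics for each action without calling the real environment.
    public static double[] UnrollRewards(double[] observation, int[] actions)
    {
        double[] hidden = Represent(observation);
        var rewards = new double[actions.Length];
        for (int k = 0; k < actions.Length; k++)
        {
            (hidden, rewards[k]) = Dynamics(hidden, actions[k]);
        }
        return rewards;
    }
}
```

The key point is that after the single call to Represent, every later step uses only the hidden state, which is what lets the search plan without knowing the environment's rules.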
Plus tree search (MCTS) for planning using the learned model.
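During tree descent, the search scores each candidate action at a node and follows the best one. The sketch below uses the pUCT rule from the MuZero paper with the paper's constants c1 = 1.25 and c2 = 19652; whether MuZeroAgent<T> uses the same formula or constants is an assumption, and the names PuctSketch, Score, and SelectChild are hypothetical. Q-values are assumed to be normalized to [0, 1] already.

```csharp
using System;

public static class PuctSketch
{
    // Constants from the MuZero paper; AiDotNet's defaults may differ.
    const double C1 = 1.25;
    const double C2 = 19652.0;

    // pUCT score for one child edge: exploitation (q) plus an exploration
    // bonus that grows with the prior and shrinks as the child is visited.
    public static double Score(double q, double prior, int childVisits, int parentVisits)
    {
        double exploration = prior * Math.Sqrt(parentVisits) / (1 + childVisits)
            * (C1 + Math.Log((parentVisits + C2 + 1) / C2));
        return q + exploration;
    }

    // Return the index of the child with the highest score.
    public static int SelectChild(double[] q, double[] priors, int[] visits)
    {
        int parentVisits = 0;
        foreach (int n in visits) parentVisits += n;

        int best = 0;
        double bestScore = double.NegativeInfinity;
        for (int a = 0; a < q.Length; a++)
        {
            double s = Score(q[a], priors[a], visits[a], parentVisits);
            if (s > bestScore) { bestScore = s; best = a; }
        }
        return best;
    }
}
```

The prior comes from the prediction function and q from values backed up through the learned dynamics, so the whole search runs on the learned model.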
Think of it as: Learning chess by watching games, figuring out the rules yourself, then planning moves by mentally simulating the game.
Famous for: Superhuman play in Go, chess, and shogi and state-of-the-art Atari results, without being given the rules
Constructors
MuZeroAgent(MuZeroOptions<T>)
public MuZeroAgent(MuZeroOptions<T> options)
Parameters
options MuZeroOptions<T>
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
int
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients Vector<T>
learningRate T
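This page does not state which update rule ApplyGradients uses. Assuming plain gradient descent (theta <- theta - learningRate * gradients), the update on a flat parameter vector, such as the one returned by GetParameters, looks like the sketch below. GradientStepSketch is hypothetical and uses double[] instead of Vector<T>; the library may apply a different optimizer.

```csharp
using System;

public static class GradientStepSketch
{
    // Assumed rule: plain gradient descent, theta <- theta - learningRate * gradient.
    // Returns a new array and leaves the input parameters unchanged.
    public static double[] Apply(double[] parameters, double[] gradients, double learningRate)
    {
        if (parameters.Length != gradients.Length)
            throw new ArgumentException("Gradient length must match parameter length.");

        var updated = new double[parameters.Length];
        for (int i = 0; i < parameters.Length; i++)
            updated[i] = parameters[i] - learningRate * gradients[i];
        return updated;
    }
}
```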
Clone()
Clones the agent.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>
Returns
- Vector<T>
Deserialize(byte[])
Deserializes the agent from bytes.
public override void Deserialize(byte[] data)
Parameters
data byte[]
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads the agent's state from a file.
public override void LoadModel(string filepath)
Parameters
filepath string
Path to load the agent from.
Predict(Vector<T>)
Makes a prediction using the trained agent.
public override Vector<T> Predict(Vector<T> input)
Parameters
input Vector<T>
Returns
- Vector<T>
PredictAsync(Vector<T>)
public Task<Vector<T>> PredictAsync(Vector<T> input)
Parameters
input Vector<T>
Returns
- Task<Vector<T>>
ResetEpisode()
Resets episode-specific state (if any).
public override void ResetEpisode()
SaveModel(string)
Saves the agent's state to a file.
public override void SaveModel(string filepath)
Parameters
filepath string
Path to save the agent to.
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> observation, bool training = true)
Parameters
observation Vector<T>
training bool
Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
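In MuZero as published, the action actually played comes from the visit counts at the search root: during training it is sampled with probability proportional to N(a)^(1/temperature), and for evaluation the most-visited action is typically taken. The sketch below shows that rule and, as an assumption about how a discrete action might be encoded in the returned Vector<T>, produces a one-hot array. VisitCountPolicySketch is a hypothetical name; MuZeroAgent<T> may select actions differently.

```csharp
using System;

public static class VisitCountPolicySketch
{
    // Index of the most-visited action (first one on ties).
    public static int MostVisited(int[] visitCounts)
    {
        int best = 0;
        for (int a = 1; a < visitCounts.Length; a++)
            if (visitCounts[a] > visitCounts[best]) best = a;
        return best;
    }

    // training = true: sample action a with probability
    //   N(a)^(1/temperature) / sum_b N(b)^(1/temperature).
    // training = false: take the most-visited action.
    // Returns a one-hot vector (assumed encoding for a discrete action).
    public static double[] SelectAction(int[] visitCounts, bool training, double temperature, Random rng)
    {
        int chosen = MostVisited(visitCounts); // also the fallback if sampling finds nothing
        if (training)
        {
            var weights = new double[visitCounts.Length];
            double total = 0;
            for (int a = 0; a < visitCounts.Length; a++)
            {
                weights[a] = Math.Pow(visitCounts[a], 1.0 / temperature);
                total += weights[a];
            }

            // Inverse-CDF sampling over the unnormalized weights.
            double u = rng.NextDouble() * total;
            double cumulative = 0;
            for (int a = 0; a < weights.Length; a++)
            {
                cumulative += weights[a];
                if (u < cumulative) { chosen = a; break; }
            }
        }

        var oneHot = new double[visitCounts.Length];
        oneHot[chosen] = 1.0;
        return oneHot;
    }
}
```

Lowering the temperature toward zero makes training-mode sampling approach the greedy evaluation choice.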
Serialize()
Serializes the agent to bytes.
public override byte[] Serialize()
Returns
- byte[]
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> observation, Vector<T> action, T reward, Vector<T> nextObservation, bool done)
Parameters
observation Vector<T>
action Vector<T>
The action taken.
reward T
The reward received.
nextObservation Vector<T>
done bool
Whether the episode terminated.
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.
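For context, the published MuZero trains its value prediction toward an n-step bootstrapped return. The sketch below computes that target, assuming rewards[k] is the reward observed after acting at step k and rootValues[k] is the search value at step k; bootstrapping is dropped when the window runs past the end of the episode. The loss that Train() reports may combine value, reward, and policy terms differently, and NStepTargetSketch is hypothetical.

```csharp
using System;

public static class NStepTargetSketch
{
    // z_t = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * v_{t+n}
    // where rewards[k] follows the action at step k and rootValues[k] is the
    // search value at step k. No bootstrap term if t + n is past the episode end.
    public static double Target(double[] rewards, double[] rootValues, int t, int n, double gamma)
    {
        double target = 0;
        double discount = 1;
        int end = Math.Min(t + n, rewards.Length);
        for (int k = t; k < end; k++)
        {
            target += discount * rewards[k];
            discount *= gamma;
        }
        if (t + n < rootValues.Length)
            target += discount * rootValues[t + n];
        return target;
    }
}
```

For example, with rewards {1, 1, 1, 1}, root values {0, 0, 0, 10}, t = 0, n = 3, and gamma = 0.5, the target is 1 + 0.5 + 0.25 + 0.125 * 10 = 3.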
TrainAsync()
public Task TrainAsync()
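The methods above are typically combined into an interaction loop. Because constructing MuZeroOptions<T> and Vector<T> is not covered on this page, the sketch below uses hypothetical interfaces (IAgentLike, IEnvironmentLike) whose method names mirror ResetEpisode, SelectAction, StoreExperience, and Train, with double[] standing in for Vector<T>, plus toy stubs so it runs. Calling Train() once per environment step is one simple schedule, not a documented requirement of MuZeroAgent<T>.

```csharp
using System;

// Hypothetical interface mirroring the documented agent methods (double[] for Vector<T>).
public interface IAgentLike
{
    double[] SelectAction(double[] observation, bool training);
    void StoreExperience(double[] observation, double[] action, double reward, double[] nextObservation, bool done);
    double Train();
    void ResetEpisode();
}

// Hypothetical environment interface.
public interface IEnvironmentLike
{
    double[] Reset();
    (double[] NextObservation, double Reward, bool Done) Step(double[] action);
}

public static class TrainingLoopSketch
{
    // Run one episode: act, store the transition, train, repeat until done or maxSteps.
    public static double RunEpisode(IAgentLike agent, IEnvironmentLike env, int maxSteps)
    {
        agent.ResetEpisode();
        double[] observation = env.Reset();
        double lastLoss = 0;
        for (int step = 0; step < maxSteps; step++)
        {
            double[] action = agent.SelectAction(observation, training: true);
            var (next, reward, done) = env.Step(action);
            agent.StoreExperience(observation, action, reward, next, done);
            lastLoss = agent.Train();
            observation = next;
            if (done) break;
        }
        return lastLoss;
    }
}

// Toy stubs so the loop can run; they do no learning.
public sealed class CountingAgent : IAgentLike
{
    public int Stored { get; private set; }
    public int TrainCalls { get; private set; }
    public double[] SelectAction(double[] observation, bool training) => new[] { 1.0, 0.0 };
    public void StoreExperience(double[] o, double[] a, double r, double[] n, bool d) => Stored++;
    public double Train() { TrainCalls++; return 0.0; }
    public void ResetEpisode() { }
}

public sealed class FixedLengthEnvironment : IEnvironmentLike
{
    private readonly int _length;
    private int _t;
    public FixedLengthEnvironment(int length) => _length = length;
    public double[] Reset() { _t = 0; return new[] { 0.0 }; }
    public (double[] NextObservation, double Reward, bool Done) Step(double[] action)
    {
        _t++;
        return (new[] { (double)_t }, 1.0, _t >= _length);
    }
}
```

With the real agent, SelectAction would run the tree search over the learned model, and StoreExperience would feed the data that Train() later learns from.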