Class DreamerAgent<T>
- Namespace
- AiDotNet.ReinforcementLearning.Agents.Dreamer
- Assembly
- AiDotNet.dll
Dreamer agent for model-based reinforcement learning.
public class DreamerAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T
The numeric type used for calculations.
- Inheritance
- DeepReinforcementLearningAgentBase<T>
- DreamerAgent<T>
- Implements
- IRLAgent<T>
Remarks
Dreamer learns a world model in latent space and uses it for planning. It combines representation learning, dynamics modeling, and policy learning.
For Beginners: Dreamer learns a "mental model" of how the environment works, then uses that model to imagine future scenarios and plan actions - like chess players thinking multiple moves ahead.
Key components:
- Representation Network: Encodes observations to latent states
- Dynamics Model: Predicts next latent state
- Reward Model: Predicts rewards
- Value Network: Estimates state values
- Actor Network: Learns policy in imagination
Think of it as: First learn physics by observation, then use that knowledge to predict "what happens if I do X" without actually doing it.
Advantages: Sample efficient, works with images, enables planning
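Examples
The following is a minimal usage sketch. Only the members documented on this page (the constructor, SelectAction, StoreExperience, Train, and ResetEpisode) are taken from the API; the parameterless DreamerOptions<T> constructor, the numeric type double, and the environment wrapper (env.Reset/env.Step) are assumptions supplied for illustration.
// Sketch only: 'env' stands in for a caller-supplied environment that returns
// observations as Vector<double> and accepts actions as Vector<double>.
var options = new DreamerOptions<double>();      // assumed parameterless constructor
var agent = new DreamerAgent<double>(options);   // the optimizer argument is optional

for (int episode = 0; episode < 100; episode++)
{
    Vector<double> observation = env.Reset();    // hypothetical environment call
    bool done = false;

    while (!done)
    {
        // training: true enables exploration
        Vector<double> action = agent.SelectAction(observation, training: true);

        // hypothetical environment step returning (next observation, reward, terminal flag)
        (Vector<double> next, double reward, bool terminated) = env.Step(action);

        agent.StoreExperience(observation, action, reward, next, terminated);
        double loss = agent.Train();             // one Dreamer update; returns the training loss

        observation = next;
        done = terminated;
    }

    agent.ResetEpisode();
}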
Constructors
DreamerAgent(DreamerOptions<T>, IOptimizer<T, Vector<T>, Vector<T>>?)
public DreamerAgent(DreamerOptions<T> options, IOptimizer<T, Vector<T>, Vector<T>>? optimizer = null)
Parameters
options: DreamerOptions<T>
optimizer: IOptimizer<T, Vector<T>, Vector<T>>
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
- int
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients: Vector<T>
learningRate: T
Clone()
Clones the agent.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for supervised learning scenarios.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input: Vector<T>
target: Vector<T>
lossFunction: ILossFunction<T>
Returns
- Vector<T>
Remarks
This method uses a simple supervised loss for compatibility with the base class API. It does NOT match the agent's internal training procedure, which uses:
- World model losses (dynamics, reward, continue prediction)
- Imagination-based policy gradients
- Value function TD errors
For actual agent training, use Train(), which implements the full Dreamer algorithm. This method is provided only for API compatibility and simple supervised fine-tuning scenarios.
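Examples
A hedged sketch of the supervised path described above. ComputeGradients and ApplyGradients are the documented members; the Vector<double> array constructor is an assumption, and the lossFunction argument is left null so the method's default loss is used.
// Supervised fine-tuning path only; the full Dreamer update is performed by Train().
Vector<double> input = new Vector<double>(new[] { 0.1, 0.2, 0.3 });   // assumed Vector constructor
Vector<double> target = new Vector<double>(new[] { 1.0, 0.0 });

// Passing null uses the default loss; a concrete ILossFunction<double> may be supplied instead.
Vector<double> gradients = agent.ComputeGradients(input, target, lossFunction: null);
agent.ApplyGradients(gradients, learningRate: 0.001);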
Deserialize(byte[])
Deserializes the agent from bytes.
public override void Deserialize(byte[] data)
Parameters
data: byte[]
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
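Examples
A small sketch of logging the metrics after an update; the metric key names are not documented here and depend on the agent's internal losses (world model, actor, value).
// Assumes T = double and an already-constructed agent.
Dictionary<string, double> metrics = agent.GetMetrics();
foreach (var metric in metrics)
{
    Console.WriteLine($"{metric.Key}: {metric.Value}");
}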
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads the agent's state from a file.
public override void LoadModel(string filepath)
Parameters
filepath: string. Path to load the agent from.
Predict(Vector<T>)
Makes a prediction using the trained agent.
public override Vector<T> Predict(Vector<T> input)
Parameters
input: Vector<T>
Returns
- Vector<T>
PredictAsync(Vector<T>)
Asynchronously makes a prediction using the trained agent.
public Task<Vector<T>> PredictAsync(Vector<T> input)
Parameters
input: Vector<T>
Returns
- Task<Vector<T>>
ResetEpisode()
Resets episode-specific state (if any).
public override void ResetEpisode()
SaveModel(string)
Saves the agent's state to a file.
public override void SaveModel(string filepath)
Parameters
filepath: string. Path to save the agent.
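Examples
A simple persistence round trip using the documented SaveModel and LoadModel members; the file path, the numeric type double, and the parameterless DreamerOptions<double> constructor are assumptions.
// Persist the trained agent, then restore it into a fresh instance built with the same options.
agent.SaveModel("dreamer_agent.bin");

var restored = new DreamerAgent<double>(new DreamerOptions<double>());  // assumed options constructor
restored.LoadModel("dreamer_agent.bin");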
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> observation, bool training = true)
Parameters
observation: Vector<T>
training: bool. Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
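Examples
For evaluation, exploration can be disabled through the documented training flag; the observation vector is assumed to come from the caller's environment.
// Deterministic action selection for evaluation episodes.
Vector<double> action = agent.SelectAction(observation, training: false);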
Serialize()
Serializes the agent to bytes.
public override byte[] Serialize()
Returns
- byte[]
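Examples
An in-memory round trip using the documented Serialize and Deserialize members; the second agent is assumed to have been constructed with the same options as the first.
// Transfer the agent's state without touching the file system
// (e.g., to ship it over a network or keep a checkpoint in memory).
byte[] state = agent.Serialize();
otherAgent.Deserialize(state);   // 'otherAgent' is another DreamerAgent<double> with matching options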
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters: Vector<T>
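Examples
A sketch of copying the flat parameter vector between two agents (for example, a training agent and an evaluation copy); both agents are assumed to share the same options so the parameter vectors have matching lengths.
Vector<double> parameters = trainingAgent.GetParameters();
evaluationAgent.SetParameters(parameters);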
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> observation, Vector<T> action, T reward, Vector<T> nextObservation, bool done)
Parameters
observation: Vector<T>
action: Vector<T>. The action taken.
reward: T. The reward received.
nextObservation: Vector<T>
done: bool. Whether the episode terminated.
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.
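Examples
A sketch of monitoring the loss returned by Train() across updates; the update count and logging interval are arbitrary, and experiences are assumed to have been stored beforehand via StoreExperience.
// Run repeated Dreamer updates and log the loss periodically.
for (int step = 0; step < 1000; step++)
{
    double loss = agent.Train();
    if (step % 100 == 0)
    {
        Console.WriteLine($"step {step}: loss = {loss}");
    }
}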
TrainAsync()
Asynchronously performs one training step.
public Task TrainAsync()
Returns
- Task