Class DreamerAgent<T>

Namespace: AiDotNet.ReinforcementLearning.Agents.Dreamer
Assembly: AiDotNet.dll

Dreamer agent for model-based reinforcement learning.

public class DreamerAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
DeepReinforcementLearningAgentBase<T>
DreamerAgent<T>
Implements
IRLAgent<T>
IFullModel<T, Vector<T>, Vector<T>>
IModel<Vector<T>, Vector<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Vector<T>, Vector<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Vector<T>, Vector<T>>>
IGradientComputable<T, Vector<T>, Vector<T>>
IJitCompilable<T>
IDisposable

Remarks

Dreamer learns a world model in latent space and uses it for planning. It combines representation learning, dynamics modeling, and policy learning.

For Beginners: Dreamer learns a "mental model" of how the environment works, then uses that model to imagine future scenarios and plan actions, much like a chess player thinking several moves ahead.

Key components:

  • Representation Network: Encodes observations to latent states
  • Dynamics Model: Predicts next latent state
  • Reward Model: Predicts rewards
  • Value Network: Estimates state values
  • Actor Network: Learns policy in imagination

Think of it as: First learn physics by observation, then use that knowledge to predict "what happens if I do X" without actually doing it.

Advantages: Sample efficient, works with images, enables planning
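
Example

The sketch below shows one way these pieces typically fit together in an interaction loop. The env object and its Reset/Step calls are hypothetical placeholders for your environment, and options stands for a configured DreamerOptions<double>; only the DreamerAgent<T> members are from this class.

var agent = new DreamerAgent<double>(options);

for (int episode = 0; episode < 1000; episode++)
{
    Vector<double> observation = env.Reset();   // hypothetical environment call
    bool done = false;

    while (!done)
    {
        // Act with exploration enabled while training.
        Vector<double> action = agent.SelectAction(observation, training: true);

        // Step the hypothetical environment: next observation, reward, done flag.
        var (nextObservation, reward, isDone) = env.Step(action);

        // Store the transition for world-model and imagination-based policy learning.
        agent.StoreExperience(observation, action, reward, nextObservation, isDone);

        // One Dreamer update; the returned loss can be logged for monitoring.
        double loss = agent.Train();

        observation = nextObservation;
        done = isDone;
    }

    agent.ResetEpisode();
}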

Constructors

DreamerAgent(DreamerOptions<T>, IOptimizer<T, Vector<T>, Vector<T>>?)

public DreamerAgent(DreamerOptions<T> options, IOptimizer<T, Vector<T>, Vector<T>>? optimizer = null)

Parameters

options DreamerOptions<T>

Configuration options for the Dreamer agent.

optimizer IOptimizer<T, Vector<T>, Vector<T>>

Optional optimizer used to update the agent's networks; may be null.
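
Example

A construction sketch. It assumes DreamerOptions<T> has a parameterless constructor and is configured through settable properties; the specific property names are not shown because they are not documented here.

// Configure the agent (specific settings depend on DreamerOptions<T>).
var options = new DreamerOptions<double>();

// The optimizer argument is optional and defaults to null.
var agent = new DreamerAgent<double>(options);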

Properties

FeatureCount

Gets the number of input features (state dimensions).

public override int FeatureCount { get; }

Property Value

int

Methods

ApplyGradients(Vector<T>, T)

Applies gradients to update the agent.

public override void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>
learningRate T

Clone()

Clones the agent.

public override IFullModel<T, Vector<T>, Vector<T>> Clone()

Returns

IFullModel<T, Vector<T>, Vector<T>>

ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)

Computes gradients for supervised learning scenarios.

public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>

Returns

Vector<T>

Remarks

Note: This method uses a simple supervised loss for compatibility with the base-class API. It does not match the agent's internal training procedure, which uses:

  • World model losses (dynamics, reward, continue prediction)
  • Imagination-based policy gradients
  • Value function TD errors

For actual agent training, use Train() which implements the full Dreamer algorithm. This method is provided only for API compatibility and simple supervised fine-tuning scenarios.
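
Example

A sketch of the supervised-compatibility path described above, paired with ApplyGradients(Vector<T>, T). The input and target vectors are illustrative placeholders; full agent training should go through Train() instead.

// Simple supervised fine-tuning step (NOT the full Dreamer update).
Vector<double> gradients = agent.ComputeGradients(input, target);   // null lossFunction uses the default supervised loss
agent.ApplyGradients(gradients, learningRate: 0.001);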

Deserialize(byte[])

Deserializes the agent from bytes.

public override void Deserialize(byte[] data)

Parameters

data byte[]

GetMetrics()

Gets the current training metrics.

public override Dictionary<string, T> GetMetrics()

Returns

Dictionary<string, T>

Dictionary of metric names to values.
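
Example

A logging sketch; the specific metric names are implementation-defined and not assumed here.

Dictionary<string, double> metrics = agent.GetMetrics();
foreach (KeyValuePair<string, double> entry in metrics)
{
    Console.WriteLine($"{entry.Key}: {entry.Value}");
}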

GetModelMetadata()

Gets model metadata.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetParameters()

Gets the agent's parameters.

public override Vector<T> GetParameters()

Returns

Vector<T>

LoadModel(string)

Loads the agent's state from a file.

public override void LoadModel(string filepath)

Parameters

filepath string

Path to load the agent from.

Predict(Vector<T>)

Makes a prediction using the trained agent.

public override Vector<T> Predict(Vector<T> input)

Parameters

input Vector<T>

Returns

Vector<T>

PredictAsync(Vector<T>)

Asynchronously makes a prediction using the trained agent.

public Task<Vector<T>> PredictAsync(Vector<T> input)

Parameters

input Vector<T>

Returns

Task<Vector<T>>

ResetEpisode()

Resets episode-specific state (if any).

public override void ResetEpisode()

SaveModel(string)

Saves the agent's state to a file.

public override void SaveModel(string filepath)

Parameters

filepath string

Path to save the agent.
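
Example

A checkpointing sketch; the path and file extension are arbitrary choices, not a required format.

// Persist the trained agent to disk.
agent.SaveModel("checkpoints/dreamer-agent.bin");

// Later (or in another process), rebuild the agent with the same options and restore its state.
var restored = new DreamerAgent<double>(options);
restored.LoadModel("checkpoints/dreamer-agent.bin");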

SelectAction(Vector<T>, bool)

Selects an action given the current state observation.

public override Vector<T> SelectAction(Vector<T> observation, bool training = true)

Parameters

observation Vector<T>

The current state observation.

training bool

Whether the agent is in training mode (affects exploration).

Returns

Vector<T>

Action as a Vector (can be discrete or continuous).
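
Example

A sketch contrasting training-time and evaluation-time action selection; observation is assumed to come from your environment.

// Exploratory action while collecting experience.
Vector<double> exploratoryAction = agent.SelectAction(observation, training: true);

// Exploration disabled when evaluating the learned policy.
Vector<double> evaluationAction = agent.SelectAction(observation, training: false);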

Serialize()

Serializes the agent to bytes.

public override byte[] Serialize()

Returns

byte[]
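
Example

An in-memory round-trip sketch using the byte-array API (the file-based SaveModel/LoadModel pair is shown above).

// Capture the agent's state as bytes, e.g. to store in a database or send over a network.
byte[] state = agent.Serialize();

// Restore that state into a fresh agent built with the same options.
var restored = new DreamerAgent<double>(options);
restored.Deserialize(state);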

SetParameters(Vector<T>)

Sets the agent's parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>
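
Example

A parameter-transfer sketch: copying the learned weights into a second agent built with the same options, for example to keep a separate evaluation copy.

// Extract the current weights as a flat vector...
Vector<double> parameters = agent.GetParameters();

// ...and load them into another agent with an identical architecture.
var evaluationAgent = new DreamerAgent<double>(options);
evaluationAgent.SetParameters(parameters);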

StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)

Stores an experience tuple for later learning.

public override void StoreExperience(Vector<T> observation, Vector<T> action, T reward, Vector<T> nextObservation, bool done)

Parameters

observation Vector<T>

The observation before the action was taken.

action Vector<T>

The action taken.

reward T

The reward received.

nextObservation Vector<T>

The observation after the action was taken.

done bool

Whether the episode terminated.

Train()

Performs one training step, updating the agent's policy/value function.

public override T Train()

Returns

T

The training loss for monitoring.

TrainAsync()

Asynchronously performs one training step.

public Task TrainAsync()

Returns

Task
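
Example

An async sketch (inside an async method) using the asynchronous entry points; observation is assumed to come from your environment.

// Run a training step without blocking the calling thread.
await agent.TrainAsync();

// Asynchronously compute a prediction for the current observation.
Vector<double> action = await agent.PredictAsync(observation);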