Class DecisionTransformerAgent<T>

Decision Transformer agent for offline reinforcement learning.

public class DecisionTransformerAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
DeepReinforcementLearningAgentBase<T>
DecisionTransformerAgent<T>
Implements
IRLAgent<T>
IFullModel<T, Vector<T>, Vector<T>>
IModel<Vector<T>, Vector<T>, ModelMetadata<T>>
IModelSerializer
ICheckpointableModel
IParameterizable<T, Vector<T>, Vector<T>>
IFeatureAware
IFeatureImportance<T>
ICloneable<IFullModel<T, Vector<T>, Vector<T>>>
IGradientComputable<T, Vector<T>, Vector<T>>
IJitCompilable<T>
IDisposable

Remarks

Decision Transformer treats RL as sequence modeling, using transformer architecture to predict actions conditioned on desired returns-to-go.

For Beginners: Instead of learning "what's the best action", Decision Transformer learns "what action was taken when the outcome was X". At test time, you specify the desired outcome, and it generates the action sequence.

Key innovations:

  • Return Conditioning: Specify target return, get actions that achieve it
  • Sequence Modeling: Uses transformers like GPT for temporal dependencies
  • No RL Updates: Just supervised learning on (return, state, action) sequences
  • Offline-First: Designed for learning from fixed datasets

Think of it as: "Show me examples of successful games, and I'll learn to generate moves that lead to that level of success."

Famous for: the UC Berkeley/Meta research (Chen et al., 2021) that reframed RL as sequence modeling.
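
Example

A minimal end-to-end sketch of the offline workflow. It assumes double as the numeric type, a parameterless DecisionTransformerOptions<double> (its individual settings are library-specific and not shown on this page), and that Vector<double> can be constructed from a double array. Using directives for the library's namespaces are omitted, and trajectories and GetCurrentState() are placeholders for your own dataset and environment code.

// Configure and create the agent.
var options = new DecisionTransformerOptions<double>();
var agent = new DecisionTransformerAgent<double>(options);

// Load a fixed dataset of complete trajectories: one (state, action, reward) tuple per timestep.
agent.LoadOfflineData(trajectories);

// Supervised-style training on the offline sequences.
for (int step = 0; step < 10_000; step++)
{
    double loss = agent.Train();
}

// At inference time, condition the policy on the return you want to achieve.
Vector<double> state = GetCurrentState();   // placeholder: your environment observation
Vector<double> action = agent.SelectActionWithReturn(state, targetReturn: 100.0, training: false);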

Constructors

DecisionTransformerAgent(DecisionTransformerOptions<T>, IOptimizer<T, Vector<T>, Vector<T>>?)

public DecisionTransformerAgent(DecisionTransformerOptions<T> options, IOptimizer<T, Vector<T>, Vector<T>>? optimizer = null)

Parameters

options DecisionTransformerOptions<T>

Configuration options for the Decision Transformer agent.

optimizer IOptimizer<T, Vector<T>, Vector<T>>

Optional optimizer used for training updates; may be null.
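
Example

A construction sketch, assuming double as the numeric type. Passing null (or omitting the argument) presumably falls back to a default optimizer; AdamOptimizer below is a hypothetical concrete type that only illustrates where a custom IOptimizer implementation would be supplied.

// Simplest form: rely on the default optimizer.
var agent = new DecisionTransformerAgent<double>(new DecisionTransformerOptions<double>());

// Optionally supply a custom optimizer (hypothetical concrete type shown).
IOptimizer<double, Vector<double>, Vector<double>> optimizer =
    new AdamOptimizer<double, Vector<double>, Vector<double>>();
var tunedAgent = new DecisionTransformerAgent<double>(new DecisionTransformerOptions<double>(), optimizer);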

Properties

FeatureCount

Gets the number of input features (state dimensions).

public override int FeatureCount { get; }

Property Value

int

Methods

ApplyGradients(Vector<T>, T)

Applies gradients to update the agent.

public override void ApplyGradients(Vector<T> gradients, T learningRate)

Parameters

gradients Vector<T>
learningRate T

Clone()

Clones the agent.

public override IFullModel<T, Vector<T>, Vector<T>> Clone()

Returns

IFullModel<T, Vector<T>, Vector<T>>

ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)

Computes gradients for the agent.

public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)

Parameters

input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>

Returns

Vector<T>
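
Example

ComputeGradients and ApplyGradients support a manual update step outside of Train(). A sketch assuming double as the numeric type; input and target are placeholders for a prepared training pair, and the lossFunction argument is omitted, which presumably falls back to a default loss.

// Compute gradients for one (input, target) pair with the default loss.
Vector<double> gradients = agent.ComputeGradients(input, target);

// Apply them with an explicit learning rate.
agent.ApplyGradients(gradients, learningRate: 1e-4);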

Deserialize(byte[])

Deserializes the agent from bytes.

public override void Deserialize(byte[] data)

Parameters

data byte[]

GetMetrics()

Gets the current training metrics.

public override Dictionary<string, T> GetMetrics()

Returns

Dictionary<string, T>

Dictionary of metric names to values.
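
Example

The returned dictionary is convenient for logging. Metric names are defined by the library, so this sketch (double numeric type assumed) iterates rather than assuming specific keys.

Dictionary<string, double> metrics = agent.GetMetrics();
foreach (var metric in metrics)
{
    Console.WriteLine($"{metric.Key}: {metric.Value}");
}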

GetModelMetadata()

Gets model metadata.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

GetParameters()

Gets the agent's parameters.

public override Vector<T> GetParameters()

Returns

Vector<T>

LoadModel(string)

Loads the agent's state from a file.

public override void LoadModel(string filepath)

Parameters

filepath string

Path to load the agent from.

LoadOfflineData(List<List<(Vector<T> state, Vector<T> action, T reward)>>)

Loads an offline dataset into the trajectory buffer. The dataset should contain complete trajectories of per-step (state, action, reward) tuples so that returns-to-go can be computed for each timestep.

public void LoadOfflineData(List<List<(Vector<T> state, Vector<T> action, T reward)>> trajectories)

Parameters

trajectories List<List<(Vector<T> state, Vector<T> action, T reward)>>
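
Example

A sketch of assembling a tiny dataset, assuming double as the numeric type and that Vector<double> can be constructed from a double array; the state and action dimensions here are arbitrary.

var trajectories = new List<List<(Vector<double> state, Vector<double> action, double reward)>>();

// One complete episode: one (state, action, reward) tuple per timestep.
var episode = new List<(Vector<double> state, Vector<double> action, double reward)>
{
    (new Vector<double>(new[] { 0.1, 0.2 }), new Vector<double>(new[] { 1.0 }), 0.0),
    (new Vector<double>(new[] { 0.3, 0.1 }), new Vector<double>(new[] { 0.0 }), 1.0)
};
trajectories.Add(episode);

agent.LoadOfflineData(trajectories);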

Predict(Vector<T>)

Makes a prediction using the trained agent.

public override Vector<T> Predict(Vector<T> input)

Parameters

input Vector<T>

Returns

Vector<T>

PredictAsync(Vector<T>)

Asynchronously makes a prediction using the trained agent.

public Task<Vector<T>> PredictAsync(Vector<T> input)

Parameters

input Vector<T>

Returns

Task<Vector<T>>

ResetEpisode()

Resets episode-specific state (if any).

public override void ResetEpisode()

SaveModel(string)

Saves the agent's state to a file.

public override void SaveModel(string filepath)

Parameters

filepath string

Path to save the agent.
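
Example

SaveModel and LoadModel form a simple persistence round trip; the on-disk format is internal to the library. Restoring into a freshly constructed agent with compatible options is assumed here.

// Persist a trained agent to disk.
agent.SaveModel("decision_transformer.bin");

// Later, restore it into a new agent instance.
var restored = new DecisionTransformerAgent<double>(options);
restored.LoadModel("decision_transformer.bin");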

SelectAction(Vector<T>, bool)

Selects an action given the current state observation.

public override Vector<T> SelectAction(Vector<T> state, bool training = true)

Parameters

state Vector<T>

The current state observation as a Vector.

training bool

Whether the agent is in training mode (affects exploration).

Returns

Vector<T>

Action as a Vector (can be discrete or continuous).

SelectActionWithReturn(Vector<T>, T, bool)

Select action conditioned on desired return-to-go.

public Vector<T> SelectActionWithReturn(Vector<T> state, T targetReturn, bool training = true)

Parameters

state Vector<T>

The current state observation.

targetReturn T

The desired return-to-go to condition the action on.

training bool

Whether the agent is in training mode.

Returns

Vector<T>
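
Example

The target return is the knob that distinguishes Decision Transformer from an ordinary policy: a higher return-to-go asks for actions consistent with higher-performing trajectories in the training data. A sketch assuming double; state is a placeholder for your environment observation, and the return values are arbitrary.

// Ask for behavior consistent with an expert-level outcome.
Vector<double> expertAction = agent.SelectActionWithReturn(state, targetReturn: 300.0, training: false);

// Ask for behavior consistent with a mediocre outcome (useful for analysis).
Vector<double> averageAction = agent.SelectActionWithReturn(state, targetReturn: 50.0, training: false);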

Serialize()

Serializes the agent to bytes.

public override byte[] Serialize()

Returns

byte[]
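
Example

Serialize and Deserialize provide an in-memory counterpart to SaveModel/LoadModel, useful for storing the agent in a database or sending it over a network. A sketch assuming both agents were constructed with compatible options.

// Capture the agent's state as a byte array.
byte[] blob = agent.Serialize();

// Restore that state into another agent instance.
var other = new DecisionTransformerAgent<double>(options);
other.Deserialize(blob);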

SetParameters(Vector<T>)

Sets the agent's parameters.

public override void SetParameters(Vector<T> parameters)

Parameters

parameters Vector<T>
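
Example

GetParameters and SetParameters expose the model weights as a flat vector, which makes it straightforward to copy weights between two agents built with the same architecture. sourceAgent and targetAgent are placeholders.

// Transfer weights from one agent to another with the same architecture.
Vector<double> weights = sourceAgent.GetParameters();
targetAgent.SetParameters(weights);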

StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)

Stores an experience tuple for later learning.

public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)

Parameters

state Vector<T>

The state before action.

action Vector<T>

The action taken.

reward T

The reward received.

nextState Vector<T>

The state after action.

done bool

Whether the episode terminated.
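
Example

Although the agent is designed offline-first, StoreExperience lets you append transitions gathered from a live environment. env is a placeholder for your own environment wrapper with a hypothetical Step method.

// Collect one transition and store it for later training.
Vector<double> action = agent.SelectAction(state, training: true);
(Vector<double> nextState, double reward, bool done) = env.Step(action);   // hypothetical environment API
agent.StoreExperience(state, action, reward, nextState, done);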

Train()

Performs one training step, updating the agent's policy/value function.

public override T Train()

Returns

T

The training loss for monitoring.
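
Example

The returned loss can be used to monitor convergence. A minimal loop sketch assuming double; the step budget and loss threshold are arbitrary.

double loss = double.MaxValue;
for (int step = 0; step < 5_000 && loss > 0.01; step++)
{
    loss = agent.Train();
    if (step % 500 == 0)
    {
        Console.WriteLine($"step {step}: loss = {loss}");
    }
}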

TrainAsync()

Asynchronously performs one training step.

public Task TrainAsync()

Returns

Task
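
Example

TrainAsync and PredictAsync mirror Train and Predict for asynchronous callers. A sketch intended for use inside an async method; input is a placeholder observation.

// Run a training step without blocking the caller.
await agent.TrainAsync();

// Query the model asynchronously.
Vector<double> prediction = await agent.PredictAsync(input);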