Class DecisionTransformerAgent<T>
- Assembly
- AiDotNet.dll
Decision Transformer agent for offline reinforcement learning.
public class DecisionTransformerAgent<T> : DeepReinforcementLearningAgentBase<T>, IRLAgent<T>, IFullModel<T, Vector<T>, Vector<T>>, IModel<Vector<T>, Vector<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Vector<T>, Vector<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Vector<T>, Vector<T>>>, IGradientComputable<T, Vector<T>, Vector<T>>, IJitCompilable<T>, IDisposable
Type Parameters
T
The numeric type used for calculations.
- Inheritance
- DeepReinforcementLearningAgentBase<T>
- DecisionTransformerAgent<T>
- Implements
- IRLAgent<T>
Remarks
Decision Transformer treats reinforcement learning as a sequence-modeling problem, using a transformer architecture to predict actions conditioned on desired returns-to-go.
For Beginners: Instead of learning "what's the best action", Decision Transformer learns "what action was taken when the outcome was X". At test time, you specify the desired outcome, and it generates the action sequence.
Key innovation:
- Return Conditioning: Specify target return, get actions that achieve it
- Sequence Modeling: Uses transformers like GPT for temporal dependencies
- No RL Updates: Just supervised learning on (return, state, action) sequences
- Offline-First: Designed for learning from fixed datasets
Think of it as: "Show me examples of successful games, and I'll learn to generate moves that lead to that level of success."
Famous for: Berkeley/Meta research simplifying RL to sequence modeling
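Examples
A minimal end-to-end sketch of the offline workflow, not a verbatim sample from the library: it assumes DecisionTransformerOptions<T> has a usable parameterless constructor and that Vector<T> can be constructed from a plain array; trajectories would be assembled as shown under LoadOfflineData below.
var options = new DecisionTransformerOptions<double>();
var agent = new DecisionTransformerAgent<double>(options);

// 1. Provide a fixed dataset of complete trajectories (offline-first).
agent.LoadOfflineData(trajectories);

// 2. Supervised-style training on (return, state, action) sequences.
for (int step = 0; step < 10_000; step++)
{
    agent.Train();   // returns the training loss for monitoring
}

// 3. At evaluation time, condition on the return you want to achieve.
var state = new Vector<double>(new[] { 0.0, 1.0, 0.5, -0.2 });   // assumed Vector<T> constructor
Vector<double> action = agent.SelectActionWithReturn(state, targetReturn: 100.0, training: false);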
Constructors
DecisionTransformerAgent(DecisionTransformerOptions<T>, IOptimizer<T, Vector<T>, Vector<T>>?)
Initializes a new instance of the DecisionTransformerAgent<T> class with the specified options and an optional optimizer.
public DecisionTransformerAgent(DecisionTransformerOptions<T> options, IOptimizer<T, Vector<T>, Vector<T>>? optimizer = null)
Parameters
options DecisionTransformerOptions<T>
optimizer IOptimizer<T, Vector<T>, Vector<T>>
Properties
FeatureCount
Gets the number of input features (state dimensions).
public override int FeatureCount { get; }
Property Value
- int
Methods
ApplyGradients(Vector<T>, T)
Applies gradients to update the agent.
public override void ApplyGradients(Vector<T> gradients, T learningRate)
Parameters
gradients Vector<T>
learningRate T
Clone()
Clones the agent.
public override IFullModel<T, Vector<T>, Vector<T>> Clone()
Returns
- IFullModel<T, Vector<T>, Vector<T>>
ComputeGradients(Vector<T>, Vector<T>, ILossFunction<T>?)
Computes gradients for the agent.
public override Vector<T> ComputeGradients(Vector<T> input, Vector<T> target, ILossFunction<T>? lossFunction = null)
Parameters
input Vector<T>
target Vector<T>
lossFunction ILossFunction<T>
Returns
- Vector<T>
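Examples
A sketch of a manual update step using ComputeGradients and ApplyGradients; in normal use Train() performs this internally. Here, agent is an already-constructed DecisionTransformerAgent<double>, and input, target, and the learning rate are placeholders.
// Omitting lossFunction presumably falls back to the agent's default loss (it is an optional parameter).
Vector<double> gradients = agent.ComputeGradients(input, target);
agent.ApplyGradients(gradients, learningRate: 1e-4);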
Deserialize(byte[])
Deserializes the agent from bytes.
public override void Deserialize(byte[] data)
Parameters
data byte[]
GetMetrics()
Gets the current training metrics.
public override Dictionary<string, T> GetMetrics()
Returns
- Dictionary<string, T>
Dictionary of metric names to values.
GetModelMetadata()
Gets model metadata.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
GetParameters()
Gets the agent's parameters.
public override Vector<T> GetParameters()
Returns
- Vector<T>
LoadModel(string)
Loads the agent's state from a file.
public override void LoadModel(string filepath)
Parameters
filepath string
Path to load the agent from.
LoadOfflineData(List<List<(Vector<T> state, Vector<T> action, T reward)>>)
Load offline dataset into the trajectory buffer. Dataset should contain complete trajectories with computed returns-to-go.
public void LoadOfflineData(List<List<(Vector<T> state, Vector<T> action, T reward)>> trajectories)
Parameters
trajectories List<List<(Vector<T> state, Vector<T> action, T reward)>>
The offline dataset; each inner list is one complete trajectory of (state, action, reward) tuples.
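Examples
A sketch of assembling an offline dataset; agent is an already-constructed DecisionTransformerAgent<double>, the Vector<T> constructor and values are illustrative assumptions. The return-to-go at step t is the sum of rewards from step t to the end of the episode (with rewards 0.0 then 1.0 below, both steps have a return-to-go of 1.0); whether the buffer derives returns-to-go from the stored rewards or expects separate preprocessing should be checked against the library.
using System.Collections.Generic;

// Each inner list is one complete episode of (state, action, reward) tuples.
var trajectories = new List<List<(Vector<double> state, Vector<double> action, double reward)>>
{
    new List<(Vector<double> state, Vector<double> action, double reward)>
    {
        (new Vector<double>(new[] { 0.0, 1.0 }), new Vector<double>(new[] { 1.0 }), 0.0),
        (new Vector<double>(new[] { 0.5, 0.8 }), new Vector<double>(new[] { 0.0 }), 1.0),
    }
};

agent.LoadOfflineData(trajectories);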
Predict(Vector<T>)
Makes a prediction using the trained agent.
public override Vector<T> Predict(Vector<T> input)
Parameters
input Vector<T>
Returns
- Vector<T>
PredictAsync(Vector<T>)
Makes a prediction asynchronously using the trained agent.
public Task<Vector<T>> PredictAsync(Vector<T> input)
Parameters
input Vector<T>
Returns
- Task<Vector<T>>
ResetEpisode()
Resets episode-specific state (if any).
public override void ResetEpisode()
SaveModel(string)
Saves the agent's state to a file.
public override void SaveModel(string filepath)
Parameters
filepath string
Path to save the agent.
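Examples
A sketch of persisting a trained agent and restoring it into a fresh instance; the file name and the parameterless options constructor are assumptions, and the on-disk format is internal to the library.
agent.SaveModel("decision_transformer.bin");

var restored = new DecisionTransformerAgent<double>(new DecisionTransformerOptions<double>());
restored.LoadModel("decision_transformer.bin");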
SelectAction(Vector<T>, bool)
Selects an action given the current state observation.
public override Vector<T> SelectAction(Vector<T> state, bool training = true)
Parameters
state Vector<T>
The current state observation as a Vector.
training bool
Whether the agent is in training mode (affects exploration).
Returns
- Vector<T>
Action as a Vector (can be discrete or continuous).
SelectActionWithReturn(Vector<T>, T, bool)
Select action conditioned on desired return-to-go.
public Vector<T> SelectActionWithReturn(Vector<T> state, T targetReturn, bool training = true)
Parameters
state Vector<T>
The current state observation.
targetReturn T
The desired return-to-go to condition the action on.
training bool
Whether the agent is in training mode.
Returns
- Vector<T>
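Examples
A sketch of the usual return-conditioned evaluation loop, where the remaining target return is reduced by each reward actually received; agent is an already-constructed, trained DecisionTransformerAgent<double>, and env with its Reset/Step methods is a hypothetical environment wrapper, not part of this API.
double targetReturn = 100.0;              // the outcome you want the agent to achieve
Vector<double> state = env.Reset();       // hypothetical environment API
bool done = false;

while (!done)
{
    Vector<double> action = agent.SelectActionWithReturn(state, targetReturn, training: false);

    var (nextState, reward, isDone) = env.Step(action);   // hypothetical environment API
    targetReturn -= reward;                                // shrink the remaining return-to-go
    state = nextState;
    done = isDone;
}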
Serialize()
Serializes the agent to bytes.
public override byte[] Serialize()
Returns
- byte[]
SetParameters(Vector<T>)
Sets the agent's parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters Vector<T>
StoreExperience(Vector<T>, Vector<T>, T, Vector<T>, bool)
Stores an experience tuple for later learning.
public override void StoreExperience(Vector<T> state, Vector<T> action, T reward, Vector<T> nextState, bool done)
Parameters
state Vector<T>
The state before the action.
action Vector<T>
The action taken.
reward T
The reward received.
nextState Vector<T>
The state after the action.
done bool
Whether the episode terminated.
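Examples
A sketch of storing a single transition gathered online; the class is offline-first, so LoadOfflineData is the primary way to provide data. Here agent is an already-constructed DecisionTransformerAgent<double>, state is the current observation, and env.Step is a hypothetical environment call.
Vector<double> action = agent.SelectAction(state, training: true);
var (nextState, reward, done) = env.Step(action);          // hypothetical environment API
agent.StoreExperience(state, action, reward, nextState, done);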
Train()
Performs one training step, updating the agent's policy/value function.
public override T Train()
Returns
- T
The training loss for monitoring.
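Examples
A sketch of a basic training loop that monitors the returned loss and the metrics dictionary; it assumes agent is a DecisionTransformerAgent<double> with offline data already loaded, and the step counts and logging cadence are illustrative.
using System;

for (int step = 0; step < 50_000; step++)
{
    double loss = agent.Train();

    if (step % 1_000 == 0)
    {
        Console.WriteLine($"step {step}, loss {loss}");
        foreach (var kvp in agent.GetMetrics())
        {
            Console.WriteLine($"  {kvp.Key}: {kvp.Value}");
        }
    }
}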
TrainAsync()
Performs one training step asynchronously.
public Task TrainAsync()
Returns
- Task