Class ReLoRAAdapter<T>
Restart LoRA (ReLoRA) adapter that periodically merges and restarts LoRA training for continual learning.
public class ReLoRAAdapter<T> : LoRAAdapterBase<T>, IDisposable, ILoRAAdapter<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
LayerBase<T> → LoRAAdapterBase<T> → ReLoRAAdapter<T>
- Implements
ILoRAAdapter<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>, IDisposable
- Inherited Members
Remarks
ReLoRA addresses the challenge of continual learning and long-running training by periodically:
1. Merging the LoRA weights into the base layer (accumulating the adaptation)
2. Resetting the LoRA matrices to restart training fresh
3. Continuing training with a clean slate while preserving previous learning
This approach:
- Prevents catastrophic forgetting by accumulating adaptations into the base layer
- Allows continuous adaptation to new data without losing old knowledge
- Maintains parameter efficiency by resetting LoRA to small matrices
- Enables training on continuously evolving data streams
For Beginners: ReLoRA is like having multiple rounds of LoRA training.
Imagine you're fine-tuning a model on data that keeps changing:
- Round 1: Train LoRA on dataset A for 1000 steps
- Merge: Add the learned changes into the base model
- Restart: Reset LoRA matrices and train on dataset B for 1000 steps
- Merge: Add these new changes to the (already updated) base model
- Repeat...
Benefits:
- Continual learning: Can keep learning from new data indefinitely
- No catastrophic forgetting: Old knowledge is preserved in the base layer
- Parameter efficient: LoRA matrices stay small even after many restarts
- Flexible: Can adapt to distribution shifts and new tasks
How it works:
- Train normally with LoRA for N steps (restart interval)
- At step N: Merge LoRA weights → AccumulatedWeight += LoRA
- Reset LoRA matrices to zero (fresh start)
- Continue training for another N steps
- Repeat indefinitely
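The schedule above can be sketched with plain matrices. This is an illustrative simulation of the restart bookkeeping only (the "gradient step" is a stand-in, and the shapes and initialization scale are assumptions, not the library's internals):

```python
import numpy as np

def relora_training_loop(total_steps, restart_interval, rank, d_in, d_out, rng):
    """Toy sketch of the ReLoRA schedule: train, merge, reset, repeat."""
    W_accumulated = np.zeros((d_in, d_out))       # all merged LoRA cycles so far
    A = rng.normal(0, 0.02, (d_in, rank))         # LoRA down-projection (random)
    B = np.zeros((rank, d_out))                   # LoRA up-projection (starts at zero)
    restarts, step_in_cycle = 0, 0
    for _ in range(total_steps):
        B += 0.01                                 # stand-in for a gradient step
        step_in_cycle += 1
        if step_in_cycle >= restart_interval:     # ShouldRestart()
            W_accumulated += A @ B                # merge this cycle's update
            A = rng.normal(0, 0.02, (d_in, rank)) # fresh random A
            B = np.zeros((rank, d_out))           # B back to zero
            step_in_cycle = 0
            restarts += 1
    return W_accumulated, restarts, step_in_cycle
```

Running 2500 steps with an interval of 1000 produces two restarts and leaves the counter mid-cycle at 500, mirroring the "repeat indefinitely" loop described above.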
Use cases:
- Training on streaming data (news articles, user behavior, etc.)
- Adapting to distribution shifts over time
- Long-running training sessions that need checkpoints
- Multi-task learning with periodic task switches
Reference: "ReLoRA: High-Rank Training Through Low-Rank Updates" (2023) https://arxiv.org/abs/2307.05695
Constructors
ReLoRAAdapter(ILayer<T>, int, double, int, bool, bool, int)
Initializes a new ReLoRA adapter with restart-based continual learning.
public ReLoRAAdapter(ILayer<T> baseLayer, int rank, double alpha = -1, int restartInterval = 1000, bool freezeBaseLayer = true, bool useWarmup = true, int warmupSteps = 10)
Parameters
baseLayer (ILayer<T>): The layer to adapt with ReLoRA.
rank (int): The rank of the LoRA decomposition.
alpha (double): The LoRA scaling factor (defaults to rank if negative).
restartInterval (int): Number of steps between restart operations (default: 1000).
freezeBaseLayer (bool): Whether to freeze the base layer's parameters during training (default: true).
useWarmup (bool): Whether to use warmup after restarts (default: true).
warmupSteps (int): Number of warmup steps after restart (default: 10).
Remarks
For Beginners: This creates a ReLoRA adapter for continual learning.
Parameters:
- baseLayer: The layer you want to adapt continuously
- rank: Size of the LoRA matrices (lower = more efficient)
- alpha: Strength of the LoRA adaptation
- restartInterval: How often to merge and restart (in training steps)
- freezeBaseLayer: Lock the base layer weights (typical for LoRA)
- useWarmup: Use reduced learning rate after restarts (helps stability)
- warmupSteps: How many steps to warm up for
The adapter will automatically handle merging and restarting at the specified interval. You just train normally, and it takes care of the restart logic.
Exceptions
- ArgumentNullException
Thrown when baseLayer is null.
- ArgumentException
Thrown when restartInterval is invalid.
Properties
CurrentStep
Gets the current step within the current restart cycle.
public int CurrentStep { get; }
Property Value
- int
RestartCount
Gets the total number of restarts that have occurred.
public int RestartCount { get; }
Property Value
- int
RestartInterval
Gets the number of training steps between restarts.
public int RestartInterval { get; }
Property Value
- int
Methods
Backward(Tensor<T>)
Performs the backward pass through all components.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient flowing back from the next layer.
Returns
- Tensor<T>
Gradient to pass to the previous layer.
Remarks
Gradients flow through:
1. Current LoRA layer (always updated)
2. Base layer (only if not frozen)
3. Accumulated weights (contribute to the input gradient but are themselves treated as constants)
For Beginners: This is where learning happens! Gradients flow backward:
- Update current LoRA matrices
- Update base layer if not frozen
- Note: Accumulated weights are treated as constants during backprop (they only change during restart, not during normal training)
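The gradient flow can be checked numerically with plain matrices standing in for the layers (shapes and scaling here are illustrative assumptions): the accumulated weight contributes to the input gradient even though it receives no parameter update itself.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, scaling = 4, 3, 2, 1.0
W_base = rng.normal(size=(d_in, d_out))   # base layer weights (frozen here)
W_acc  = rng.normal(size=(d_in, d_out))   # accumulated cycles (always frozen)
A = rng.normal(size=(d_in, rank))         # current LoRA down-projection
B = rng.normal(size=(rank, d_out))        # current LoRA up-projection

def forward(x):
    return x @ W_base + x @ W_acc + (x @ A @ B) * scaling

x = rng.normal(size=(1, d_in))
g = rng.normal(size=(1, d_out))           # outputGradient from the next layer

# Input gradient sums all three branches' contributions...
grad_x = g @ W_base.T + g @ W_acc.T + g @ (A @ B * scaling).T
# ...but only the LoRA matrices get parameter gradients; W_acc gets none.
grad_A = x.T @ (g @ (B * scaling).T)
grad_B = (x @ A).T @ g * scaling
```

A finite-difference check of `grad_x` against `forward` confirms the three-term input gradient.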
ForceRestart()
Manually triggers a restart (useful for checkpointing or manual control).
public void ForceRestart()
Remarks
For Beginners: This forces an immediate restart, even if the interval hasn't been reached. Useful when you want to checkpoint at specific points (e.g., after completing a task or dataset).
Forward(Tensor<T>)
Performs the forward pass with accumulated LoRA adaptations.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): Input tensor.
Returns
- Tensor<T>
Base layer output plus accumulated LoRA adaptation plus current LoRA output.
Remarks
The forward pass computes: output = base_layer(input) + input * AccumulatedWeight + lora_layer(input)
For Beginners: This processes input through three components:
1. Base layer: The original layer (may have been adapted in previous cycles)
2. Accumulated weight: All previous LoRA cycles' learning
3. Current LoRA: The current cycle's adaptation
All three are added together to produce the final output.
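The documented formula can be sketched with plain matrices standing in for the layers (shapes and scaling are illustrative assumptions). Note the no-op property at initialization: before any training, B and the accumulated weight are zero, so the adapter reproduces the base layer exactly.

```python
import numpy as np

# output = base_layer(input) + input * AccumulatedWeight + lora_layer(input)
rng = np.random.default_rng(1)
d_in, d_out, rank, scaling = 4, 3, 2, 2.0
W_base = rng.normal(size=(d_in, d_out))   # original layer weights
W_acc  = np.zeros((d_in, d_out))          # previous cycles (zero before first restart)
A = rng.normal(size=(d_in, rank))         # LoRA down-projection
B = np.zeros((rank, d_out))               # LoRA up-projection (starts at zero)

def relora_forward(x):
    return x @ W_base + x @ W_acc + (x @ A @ B) * scaling

x = rng.normal(size=(2, d_in))
# With B == 0 and W_acc == 0, the output equals the base layer alone.
```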
GetAccumulatedWeight()
Gets a copy of the accumulated weight matrix.
public Matrix<T> GetAccumulatedWeight()
Returns
- Matrix<T>
A copy of the accumulated weight matrix.
MergeToOriginalLayer()
Merges all accumulated adaptations into the base layer and returns the merged layer.
public override ILayer<T> MergeToOriginalLayer()
Returns
- ILayer<T>
A new layer with all ReLoRA adaptations (accumulated + current) merged into the base layer's weights.
Remarks
This merges:
1. All accumulated LoRA weights from previous restart cycles
2. The current LoRA cycle's weights
into the base layer's weights to create a final standalone layer.
For Beginners: This "bakes in" all the ReLoRA learning into a regular layer.
This takes:
- All previous cycles' learning (from accumulated weights)
- Current cycle's learning (from current LoRA)
- Base layer weights
And combines them into a single layer that:
- Works like a normal layer (no special ReLoRA infrastructure needed)
- Contains all the adapted knowledge
- Can be deployed for fast inference
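The merge can be sketched in weight form with plain matrices (illustrative shapes and scaling, not the library's layer types): folding base + accumulated + current LoRA into one matrix reproduces the adapter's output with a single matmul.

```python
import numpy as np

rng = np.random.default_rng(2)
d_in, d_out, rank, scaling = 4, 3, 2, 1.5
W_base = rng.normal(size=(d_in, d_out))
W_acc  = rng.normal(size=(d_in, d_out))        # previous cycles, already accumulated
A = rng.normal(size=(d_in, rank))
B = rng.normal(size=(rank, d_out))

# MergeToOriginalLayer, expressed on the weights:
W_merged = W_base + W_acc + (A @ B) * scaling

x = rng.normal(size=(5, d_in))
adapter_out = x @ W_base + x @ W_acc + (x @ A @ B) * scaling
merged_out  = x @ W_merged                     # one matmul at inference time
```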
ResetState()
Resets the internal state of all layers.
public override void ResetState()
ResetStepCounter()
Resets the step counter without performing a restart (useful for aligning with external training loops).
public void ResetStepCounter()
RestartLoRA()
Performs the restart operation: merges current LoRA weights and reinitializes.
public void RestartLoRA()
Remarks
The restart process:
1. Merge current LoRA weights: W_accumulated += W_A * W_B * scaling
2. Reinitialize LoRA matrices: A gets new random values, B reset to zero
3. Reset step counter to 0
4. Increment restart count
For Beginners: This performs the "checkpoint and restart" operation.
Steps:
- Save progress: Add current LoRA changes to the accumulated total
- Fresh start: Reset LoRA matrices (A gets new random values, B starts at zero)
- Reset counter: Start counting steps from 0 again
After this, training continues normally for another cycle. The accumulated changes are preserved and will be included in the final output.
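A single restart can be sketched with plain matrices (illustrative shapes, not the library's internals). The key property: because the merged update moves into the accumulated total and the freshly reset LoRA contributes nothing (B = 0), the restart itself leaves the adapter's function unchanged.

```python
import numpy as np

rng = np.random.default_rng(3)
d_in, d_out, rank, scaling = 4, 3, 2, 1.0
W_acc = np.zeros((d_in, d_out))
A = rng.normal(size=(d_in, rank))              # a trained cycle's matrices
B = rng.normal(size=(rank, d_out))

x = rng.normal(size=(2, d_in))
out_before = x @ W_acc + (x @ A @ B) * scaling

# RestartLoRA():
W_acc = W_acc + (A @ B) * scaling              # 1. merge into the accumulated total
A = rng.normal(0, 0.02, size=(d_in, rank))     # 2. A gets new random values
B = np.zeros((rank, d_out))                    #    B reset to zero
step, restarts = 0, 1                          # 3-4. reset counter, count the restart

out_after = x @ W_acc + (x @ A @ B) * scaling  # unchanged by the restart
```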
ShouldRestart()
Checks if a restart should be performed based on the current step count.
public bool ShouldRestart()
Returns
- bool
True if the current step count has reached the restart interval; otherwise, false.
Remarks
For Beginners: This checks if it's time for a restart. Returns true when we've completed a full training cycle (reached the interval).
UpdateParameters(T)
Updates parameters with optional warmup after restarts.
public override void UpdateParameters(T learningRate)
Parameters
learningRateTThe base learning rate for parameter updates.
Remarks
If warmup is enabled, the learning rate is scaled down for the first few steps after each restart to prevent instability. Warmup schedule: lr = base_lr * (current_step / warmup_steps)
For Beginners: This updates the model's parameters using gradients.
After a restart, if warmup is enabled:
- First few steps use a reduced learning rate (gradually increasing)
- This helps the model stabilize after the restart shock
- After warmup, normal learning rate is used
Think of it like easing back into training after a checkpoint.
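The documented schedule can be sketched as a small helper (the function name is hypothetical, and capping at base_lr once warmup completes is an assumption consistent with the description above):

```python
def warmup_lr(base_lr, current_step, warmup_steps, use_warmup=True):
    """lr = base_lr * (current_step / warmup_steps) during warmup, else base_lr."""
    if not use_warmup or current_step >= warmup_steps:
        return base_lr
    return base_lr * (current_step / warmup_steps)
```

For example, with base_lr = 0.1 and warmup_steps = 10, step 5 after a restart trains at half the base rate, and step 10 onward trains at the full rate.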
UpdateParametersFromLayers()
Updates the parameter vector from the current layer states.
protected override void UpdateParametersFromLayers()