Table of Contents

Class ReLoRAAdapter<T>

Namespace
AiDotNet.LoRA.Adapters
Assembly
AiDotNet.dll

Restart LoRA (ReLoRA) adapter that periodically merges and restarts LoRA training for continual learning.

public class ReLoRAAdapter<T> : LoRAAdapterBase<T>, IDisposable, ILoRAAdapter<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
LoRAAdapterBase<T> → ReLoRAAdapter<T>
Implements
IDisposable, ILoRAAdapter<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>

Remarks

ReLoRA addresses the challenge of continual learning and long-running training by periodically:

  1. Merging the LoRA weights into the base layer (accumulating the adaptation)
  2. Resetting the LoRA matrices to restart training fresh
  3. Continuing training with a clean slate while preserving previous learning

This approach:

  • Prevents catastrophic forgetting by accumulating adaptations into the base layer
  • Allows continuous adaptation to new data without losing old knowledge
  • Maintains parameter efficiency by resetting LoRA to small matrices
  • Enables training on continuously evolving data streams

For Beginners: ReLoRA is like having multiple rounds of LoRA training.

Imagine you're fine-tuning a model on data that keeps changing:

  • Round 1: Train LoRA on dataset A for 1000 steps
  • Merge: Add the learned changes into the base model
  • Restart: Reset LoRA matrices and train on dataset B for 1000 steps
  • Merge: Add these new changes to the (already updated) base model
  • Repeat...

Benefits:

  • Continual learning: Can keep learning from new data indefinitely
  • No catastrophic forgetting: Old knowledge is preserved in the base layer
  • Parameter efficient: LoRA matrices stay small even after many restarts
  • Flexible: Can adapt to distribution shifts and new tasks

How it works:

  1. Train normally with LoRA for N steps (restart interval)
  2. At step N: Merge LoRA weights → AccumulatedWeight += LoRA
  3. Reset LoRA matrices to zero (fresh start)
  4. Continue training for another N steps
  5. Repeat indefinitely
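The merge-and-restart cycle above can be sketched in a few lines of NumPy. This is a language-agnostic illustration of the algorithm, not AiDotNet's internals; the names (`A`, `B`, `accumulated`, `scaling`) and the stand-in "gradient update" are purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, scaling = 8, 8, 2, 1.0
restart_interval = 3

# Illustrative state: LoRA factors A, B and the accumulated merge target.
A = rng.normal(scale=0.02, size=(d_in, rank))
B = np.zeros((rank, d_out))
accumulated = np.zeros((d_in, d_out))

for step in range(1, 2 * restart_interval + 1):
    # Stand-in for a real gradient update: nudge the LoRA factors.
    B += 0.01
    if step % restart_interval == 0:
        # Merge: fold the current low-rank update into the accumulated weight...
        accumulated += A @ B * scaling
        # ...then restart: A re-randomized, B back to zero, so the adapter's
        # contribution starts from exactly zero again.
        A = rng.normal(scale=0.02, size=(d_in, rank))
        B = np.zeros((rank, d_out))

# After a restart the current LoRA delta is zero, but the accumulated
# weight preserves everything learned in earlier cycles.
print(np.allclose(A @ B, 0.0))     # True
print(np.any(accumulated != 0.0))  # True
```

Because B is reset to zero, each cycle's adaptation begins at an exact zero delta while nothing learned before is lost.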

Use cases:

  • Training on streaming data (news articles, user behavior, etc.)
  • Adapting to distribution shifts over time
  • Long-running training sessions that need checkpoints
  • Multi-task learning with periodic task switches

Reference: "ReLoRA: High-Rank Training Through Low-Rank Updates" (2023) https://arxiv.org/abs/2307.05695

Constructors

ReLoRAAdapter(ILayer<T>, int, double, int, bool, bool, int)

Initializes a new ReLoRA adapter with restart-based continual learning.

public ReLoRAAdapter(ILayer<T> baseLayer, int rank, double alpha = -1, int restartInterval = 1000, bool freezeBaseLayer = true, bool useWarmup = true, int warmupSteps = 10)

Parameters

baseLayer ILayer<T>

The layer to adapt with ReLoRA.

rank int

The rank of the LoRA decomposition.

alpha double

The LoRA scaling factor (defaults to rank if negative).

restartInterval int

Number of steps between restart operations (default: 1000).

freezeBaseLayer bool

Whether to freeze the base layer's parameters during training (default: true).

useWarmup bool

Whether to use warmup after restarts (default: true).

warmupSteps int

Number of warmup steps after restart (default: 10).

Remarks

For Beginners: This creates a ReLoRA adapter for continual learning.

Parameters:

  • baseLayer: The layer you want to adapt continuously
  • rank: Size of the LoRA matrices (lower = more efficient)
  • alpha: Strength of the LoRA adaptation
  • restartInterval: How often to merge and restart (in training steps)
  • freezeBaseLayer: Lock the base layer weights (typical for LoRA)
  • useWarmup: Use reduced learning rate after restarts (helps stability)
  • warmupSteps: How many steps to warm up for

The adapter will automatically handle merging and restarting at the specified interval. You just train normally, and it takes care of the restart logic.
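The bookkeeping that makes this automatic can be sketched as follows. This is a hypothetical Python model of the step-counter logic described in this page (not AiDotNet's actual implementation; the weight math is elided):

```python
class ReLoRABookkeeping:
    """Hypothetical sketch of the automatic merge-and-restart bookkeeping."""

    def __init__(self, restart_interval=1000):
        if restart_interval <= 0:
            raise ValueError("restartInterval must be positive")
        self.restart_interval = restart_interval
        self.current_step = 0
        self.restart_count = 0

    def update_parameters(self):
        # ... the gradient update would happen here ...
        self.current_step += 1
        if self.should_restart():
            self.restart()

    def should_restart(self):
        return self.current_step >= self.restart_interval

    def restart(self):
        # Merge current LoRA into the accumulated weight and reinitialize
        # the matrices (elided here), then reset the cycle counter.
        self.restart_count += 1
        self.current_step = 0


adapter = ReLoRABookkeeping(restart_interval=5)
for _ in range(12):
    adapter.update_parameters()
print(adapter.restart_count, adapter.current_step)  # 2 2
```

Twelve updates with an interval of 5 produce two restarts, leaving the counter two steps into the third cycle; the caller never calls the restart logic explicitly.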

Exceptions

ArgumentNullException

Thrown when baseLayer is null.

ArgumentException

Thrown when restartInterval is invalid.

Properties

CurrentStep

Gets the current step within the current restart cycle.

public int CurrentStep { get; }

Property Value

int

RestartCount

Gets the total number of restarts that have occurred.

public int RestartCount { get; }

Property Value

int

RestartInterval

Gets the number of training steps between restarts.

public int RestartInterval { get; }

Property Value

int

Methods

Backward(Tensor<T>)

Performs the backward pass through all components.

public override Tensor<T> Backward(Tensor<T> outputGradient)

Parameters

outputGradient Tensor<T>

Gradient flowing back from the next layer.

Returns

Tensor<T>

Gradient to pass to the previous layer.

Remarks

Gradients flow through:

  1. Current LoRA layer (always updated)
  2. Base layer (only if not frozen)
  3. Accumulated weights (held fixed; they contribute to the output but only change at restart)

For Beginners: This is where learning happens! Gradients flow backward:

  • Update current LoRA matrices
  • Update base layer if not frozen
  • Note: Accumulated weights are treated as constants during backprop (they only change during restart, not during normal training)

ForceRestart()

Manually triggers a restart (useful for checkpointing or manual control).

public void ForceRestart()

Remarks

For Beginners: This forces an immediate restart, even if the interval hasn't been reached. Useful when you want to checkpoint at specific points (e.g., after completing a task or dataset).

Forward(Tensor<T>)

Performs the forward pass with accumulated LoRA adaptations.

public override Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor.

Returns

Tensor<T>

Base layer output plus accumulated LoRA adaptation plus current LoRA output.

Remarks

The forward pass computes: output = base_layer(input) + input * AccumulatedWeight + lora_layer(input)

For Beginners: This processes input through three components:

  1. Base layer: The original layer (may have been adapted in previous cycles)
  2. Accumulated weight: All previous LoRA cycles' learning
  3. Current LoRA: The current cycle's adaptation

All three are added together to produce the final output.
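The three-term sum can be written out in NumPy. A sketch assuming a plain linear base layer; `base_weight`, `accumulated`, and the LoRA factors are illustrative stand-ins, not the library's fields:

```python
import numpy as np

rng = np.random.default_rng(1)
d_in, d_out, rank, scaling = 4, 3, 2, 2.0

base_weight = rng.normal(size=(d_in, d_out))  # original layer weights
accumulated = rng.normal(size=(d_in, d_out))  # merged from earlier cycles
A = rng.normal(size=(d_in, rank))             # current LoRA down-projection
B = rng.normal(size=(rank, d_out))            # current LoRA up-projection

x = rng.normal(size=(1, d_in))

# output = base_layer(input) + input * AccumulatedWeight + lora_layer(input)
output = x @ base_weight + x @ accumulated + (x @ A @ B) * scaling

# By distributivity, the three terms collapse into one merged weight matrix,
# which is what MergeToOriginalLayer exploits to produce a standalone layer.
merged = base_weight + accumulated + A @ B * scaling
print(np.allclose(output, x @ merged))  # True
```

This equivalence is why the adapter can be "baked" into a single ordinary layer for deployment with no ReLoRA infrastructure at inference time.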

GetAccumulatedWeight()

Gets a copy of the accumulated weight matrix.

public Matrix<T> GetAccumulatedWeight()

Returns

Matrix<T>

MergeToOriginalLayer()

Merges all accumulated adaptations into the base layer and returns the merged layer.

public override ILayer<T> MergeToOriginalLayer()

Returns

ILayer<T>

A new layer with all ReLoRA adaptations (accumulated + current) merged into the base layer's weights.

Remarks

This merges:

  1. All accumulated LoRA weights from previous restart cycles
  2. The current LoRA cycle's weights

into the base layer's weights to create a final standalone layer.

For Beginners: This "bakes in" all the ReLoRA learning into a regular layer.

This takes:

  • All previous cycles' learning (from accumulated weights)
  • Current cycle's learning (from current LoRA)
  • Base layer weights

And combines them into a single layer that:

  • Works like a normal layer (no special ReLoRA infrastructure needed)
  • Contains all the adapted knowledge
  • Can be deployed for fast inference

ResetState()

Resets the internal state of all layers.

public override void ResetState()

ResetStepCounter()

Resets the step counter without performing a restart (useful for aligning with external training loops).

public void ResetStepCounter()

RestartLoRA()

Performs the restart operation: merges current LoRA weights and reinitializes.

public void RestartLoRA()

Remarks

The restart process:

  1. Merge current LoRA weights: W_accumulated += W_A * W_B * scaling
  2. Reinitialize LoRA matrices: A gets new random values, B reset to zero
  3. Reset step counter to 0
  4. Increment restart count

For Beginners: This performs the "checkpoint and restart" operation.

Steps:

  1. Save progress: Add current LoRA changes to the accumulated total
  2. Fresh start: Reset LoRA matrices (A gets new random values, B starts at zero)
  3. Reset counter: Start counting steps from 0 again

After this, training continues normally for another cycle. The accumulated changes are preserved and will be included in the final output.

ShouldRestart()

Checks if a restart should be performed based on the current step count.

public bool ShouldRestart()

Returns

bool

True if the current step has reached the restart interval.

Remarks

For Beginners: This checks if it's time for a restart. Returns true when we've completed a full training cycle (reached the interval).

UpdateParameters(T)

Updates parameters with optional warmup after restarts.

public override void UpdateParameters(T learningRate)

Parameters

learningRate T

The base learning rate for parameter updates.

Remarks

If warmup is enabled, the learning rate is scaled down for the first few steps after each restart to prevent instability. Warmup schedule: lr = base_lr * (current_step / warmup_steps)

For Beginners: This updates the model's parameters using gradients.

After a restart, if warmup is enabled:

  • First few steps use a reduced learning rate (gradually increasing)
  • This helps the model stabilize after the restart shock
  • After warmup, normal learning rate is used

Think of it like easing back into training after a checkpoint.
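The warmup schedule above is simple enough to compute by hand. A minimal sketch (the helper name `warmup_lr` is hypothetical):

```python
base_lr, warmup_steps = 0.1, 10

def warmup_lr(step):
    # Ramp linearly per lr = base_lr * (current_step / warmup_steps),
    # then hold at the base rate once warmup is complete.
    return base_lr * min(step / warmup_steps, 1.0)

print(warmup_lr(0))   # 0.0  -- just after a restart
print(warmup_lr(5))   # 0.05 -- halfway through warmup
print(warmup_lr(10))  # 0.1  -- warmup finished, full rate
```

So immediately after a restart the effective learning rate starts at zero and climbs back to the base rate over `warmup_steps` updates.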

UpdateParametersFromLayers()

Updates the parameter vector from the current layer states.

protected override void UpdateParametersFromLayers()