Class ChainLoRAAdapter<T>
Chain-of-LoRA adapter that implements sequential composition of multiple LoRA adapters.
public class ChainLoRAAdapter<T> : LoRAAdapterBase<T>, IDisposable, ILoRAAdapter<T>, ILayer<T>, IJitCompilable<T>, IDiagnosticsProvider, IWeightLoadable<T>
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
LayerBase<T> → LoRAAdapterBase<T> → ChainLoRAAdapter<T>
- Implements
ILoRAAdapter<T>, ILayer<T>
Remarks
Chain-of-LoRA (COLA) is an advanced LoRA technique that enables sequential composition of multiple LoRA adaptations through an iterative optimization framework. Unlike standard LoRA which applies a single low-rank adaptation, COLA builds a chain of adaptations where each adapter is trained, merged into the model, and then a new adapter is initialized for further refinement.
This approach bridges the performance gap between standard LoRA and full fine-tuning by employing residual learning principles. Each iteration in the chain adds incremental improvements to the model's task-specific performance without incurring additional computational costs or memory overhead during inference.
Key Concepts:
Sequential Adaptation: Chain-of-LoRA applies adaptations in sequence (Task A → Task B → Task C), where each stage builds upon the previous one. This is inspired by the Frank-Wolfe optimization algorithm, which makes greedy updates along the direction of maximum improvement.
Merge and Re-initialize: After training each LoRA adapter, the learned weights are merged back into the base layer, and a new LoRA adapter is initialized. This "tying a knot" process allows the model to consolidate learned knowledge before adding new adaptations.
Knowledge Preservation: By freezing the base layer and only training the LoRA components, the chain preserves previously learned knowledge while allowing new task-specific adaptations. Each adapter in the chain captures a specific aspect of the task or a refinement step.
Incremental Fine-tuning Pipeline: COLA enables continual learning scenarios where tasks are presented sequentially, and the model must adapt to new tasks while maintaining performance on previous ones.
Benefits of Chain-of-LoRA:
- Better Performance: Achieves up to 6.47% relative accuracy gain over standard LoRA
- No Extra Overhead: After merging, inference cost is identical to the base model
- Modular Adaptation: Each adapter can be trained, tested, and validated independently
- Catastrophic Forgetting Mitigation: Sequential merging helps preserve prior knowledge
- Task Chaining: Naturally supports multi-task learning and transfer learning scenarios
- Flexible Deployment: Can deploy the full chain or selected adapters as needed
For Beginners:
Imagine you're learning a complex skill in stages:
- First, you learn the basics (Adapter 1)
- Then you practice and the basics become automatic (Merge)
- Next, you learn intermediate techniques on top of the basics (Adapter 2)
- Again, you practice until they're automatic (Merge)
- Finally, you learn advanced skills building on everything before (Adapter 3)
Chain-of-LoRA works the same way: each adapter learns something new, then it's consolidated into the model, and the next adapter can focus on the next refinement. This stepwise approach often achieves better results than trying to learn everything at once.
Research Reference:
Based on "Chain of LoRA: Efficient Fine-tuning of Language Models via Residual Learning" (arXiv:2401.04151, January 2024). The paper demonstrates that sequential low-rank adaptations can significantly improve task performance compared to single-stage LoRA, especially on complex reasoning and multi-step tasks.
Usage Example:
// Create a chain with 3 sequential adaptations
var chain = new ChainLoRAAdapter<double>(baseLayer, rank: 8, chainLength: 3);
// Train first adapter on Task A
chain.SetActiveAdapterIndex(0);
TrainModel(chain, taskAData);
chain.FreezeActiveAdapter(); // Freeze Task A adapter
// Train second adapter on Task B
chain.SetActiveAdapterIndex(1);
TrainModel(chain, taskBData);
chain.FreezeActiveAdapter(); // Freeze Task B adapter
// Train third adapter on Task C
chain.SetActiveAdapterIndex(2);
TrainModel(chain, taskCData);
// Deploy: merge all adapters into base layer for optimized inference
ILayer<double> finalLayer = chain.MergeToOriginalLayer();
Constructors
ChainLoRAAdapter(ILayer<T>, int, int, double, bool)
Initializes a new Chain-of-LoRA adapter with the specified configuration.
public ChainLoRAAdapter(ILayer<T> baseLayer, int rank, int chainLength = 3, double alpha = -1, bool freezeBaseLayer = true)
Parameters
baseLayer (ILayer<T>): The layer to adapt with the LoRA chain.
rank (int): The rank of each LoRA decomposition in the chain.
chainLength (int): The number of sequential adapters in the chain (default: 3).
alpha (double): The LoRA scaling factor for each adapter (defaults to rank if negative).
freezeBaseLayer (bool): Whether to freeze the base layer's parameters during training (default: true).
Remarks
Creates a chain of LoRA adapters for sequential fine-tuning. Each adapter in the chain can be trained independently, merged into the model, and then the next adapter can be activated for further refinement.
For Beginners:
Parameters:
- baseLayer: The layer you want to adapt (e.g., a dense or convolutional layer)
- rank: How compressed each adapter is (lower = fewer parameters per stage)
- chainLength: How many sequential adaptation stages you want (typical: 2-5)
- alpha: Controls adaptation strength (usually equals rank)
- freezeBaseLayer: Lock base weights to preserve pre-trained knowledge (recommended: true)
Example: chainLength=3 means you can do three rounds of training and merging, allowing the model to incrementally improve on complex tasks.
Exceptions
- ArgumentNullException
Thrown when baseLayer is null.
- ArgumentException
Thrown when chainLength is less than 1.
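Example (a minimal sketch, reusing the baseLayer placeholder from the usage example above):
// Minimal form: rank 8, three stages by default, alpha defaults to the rank, base layer frozen
var chain = new ChainLoRAAdapter<double>(baseLayer, rank: 8);
// Explicit form: four stages, stronger scaling, base layer left trainable
var longChain = new ChainLoRAAdapter<double>(baseLayer, rank: 4, chainLength: 4, alpha: 8.0, freezeBaseLayer: false);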
Properties
ActiveAdapterIndex
Gets the index of the currently active adapter (0-based).
public int ActiveAdapterIndex { get; }
Property Value
- int
Remarks
The active adapter is the one currently being trained. Other adapters in the chain are either waiting to be trained (higher indices) or have already been trained and frozen (lower indices).
AdapterChain
Gets the list of LoRA adapters in the chain.
public IReadOnlyList<LoRALayer<T>> AdapterChain { get; }
Property Value
- IReadOnlyList<LoRALayer<T>>
Remarks
Each adapter in the chain represents one stage of sequential adaptation. Adapters are applied in order during forward passes.
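Example (a small sketch, assuming a chain instance created as shown in the usage example above):
// Inspect each stage of the chain alongside its frozen status and the active index
for (int i = 0; i < chain.AdapterChain.Count; i++)
{
    bool isActive = i == chain.ActiveAdapterIndex;
    Console.WriteLine($"Adapter {i}: frozen = {chain.FrozenStatus[i]}, active = {isActive}");
}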
ChainLength
Gets the total number of adapters in the chain.
public int ChainLength { get; }
Property Value
- int
Remarks
This represents the maximum number of sequential adaptation stages that can be applied. Each adapter can be trained independently and then merged before proceeding to the next.
FrozenStatus
Gets the frozen status of each adapter in the chain.
public IReadOnlyList<bool> FrozenStatus { get; }
Property Value
- IReadOnlyList<bool>
Remarks
True indicates that an adapter has been frozen and no longer contributes trainable parameters. Frozen adapters still participate in forward and backward passes until the entire chain is merged via MergeToOriginalLayer().
ParameterCount
Gets the total number of parameters in the chain (base layer + all unfrozen adapters).
public override int ParameterCount { get; }
Property Value
- int
Remarks
This count includes parameters from the base layer (if not frozen) plus all unfrozen adapters in the chain. Frozen adapters don't contribute to the parameter count since they no longer receive gradient updates. Returns the cached _currentParameterCount once the chain is initialized, or computes it on-the-fly during construction to handle base class initialization.
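Example (a sketch; the exact counts depend on the base layer's shape and the chosen rank):
// The trainable parameter count shrinks as adapters are frozen
Console.WriteLine($"Trainable parameters before freezing: {chain.ParameterCount}");
chain.FreezeActiveAdapter();
Console.WriteLine($"Trainable parameters after freezing adapter 0: {chain.ParameterCount}");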
Methods
Backward(Tensor<T>)
Performs the backward pass through all layers in the chain.
public override Tensor<T> Backward(Tensor<T> outputGradient)
Parameters
outputGradient (Tensor<T>): Gradient flowing back from the next layer.
Returns
- Tensor<T>
Gradient to pass to the previous layer.
Remarks
Gradients flow backward through all adapters and the base layer. Parameter updates are applied only to the active, unfrozen adapter (and to the base layer if it is not frozen).
For Beginners: During learning, this figures out how each part of the chain should change. Only the active, unfrozen adapter actually gets updated - the frozen ones preserve their learned knowledge.
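Example (a hand-rolled training step sketch; input, lossGradient, and the learning rate are placeholders for values produced by your data pipeline and loss function):
// Forward pass, backward pass, then apply the update to the active adapter
Tensor<double> prediction = chain.Forward(input);
Tensor<double> inputGradient = chain.Backward(lossGradient);
chain.UpdateParameters(0.01); // only the active, unfrozen adapter (and an unfrozen base layer) is updated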
Forward(Tensor<T>)
Performs the forward pass through the base layer and all adapters in the chain.
public override Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): Input tensor.
Returns
- Tensor<T>
Output with all adapter contributions summed.
Remarks
The forward pass computes: output = base_layer(input) + adapter_0(input) + adapter_1(input) + ... + adapter_n(input)
IMPORTANT: All adapters contribute to the output, regardless of frozen status. Frozen adapters continue to be computed in every forward pass. They are only "frozen" in the sense that they don't receive gradient updates during training. True inference optimization (eliminating frozen adapter computation) only occurs after calling MergeToOriginalLayer().
For Beginners: During inference or training, the input goes through the base layer and ALL adapters in the chain (both frozen and unfrozen). Their outputs are added together to get the final result. Freezing an adapter stops it from training, but it still contributes to every prediction.
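Example (a sketch; input is a placeholder tensor from your pipeline):
// Before merging, every adapter in the chain runs on each call, frozen or not
Tensor<double> output = chain.Forward(input);
// output = base_layer(input) + adapter_0(input) + ... + adapter_{chainLength-1}(input)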
FreezeActiveAdapter()
Freezes the currently active adapter to prevent further training.
public void FreezeActiveAdapter()
Remarks
This "ties a knot" in the chain by marking the active adapter as frozen. The adapter continues to contribute to forward passes but will no longer receive gradient updates, allowing the next adapter in the chain to build upon this consolidated knowledge.
IMPORTANT: This method does NOT merge weights into the base layer. All adapters (frozen or not) remain active during forward/backward passes. True weight merging only occurs when MergeToOriginalLayer() is called at the end of training.
For Beginners: After training an adapter stage, call this to "lock it in" before moving to the next stage. The adapter's learned knowledge is preserved and it stops training, but it still contributes to the model's output. Think of it like finishing one chapter before starting the next - the previous chapter's knowledge remains active.
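Example (a sketch continuing the usage example above):
// Lock in stage 0, check the bookkeeping, then explicitly select stage 1
chain.FreezeActiveAdapter();
Console.WriteLine($"Frozen: {chain.GetFrozenCount()}, trainable: {chain.GetTrainableAdapterCount()}");
chain.SetActiveAdapterIndex(1);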
GetFrozenCount()
Gets the number of adapters that have been frozen.
public int GetFrozenCount()
Returns
- int
Count of frozen adapters.
GetParameters()
Gets the current parameters as a vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
Vector containing parameters from the base layer (if not frozen) and all unfrozen adapters.
GetTrainableAdapterCount()
Gets the number of adapters that are still trainable (not frozen).
public int GetTrainableAdapterCount()
Returns
- int
Count of unfrozen adapters.
MergeToOriginalLayer()
Merges all adapters in the chain into the original base layer.
public override ILayer<T> MergeToOriginalLayer()
Returns
- ILayer<T>
A new layer with all LoRA adaptations merged into the base weights.
Remarks
This creates a single layer that includes all the sequential adaptations from the chain. The resulting layer has the same computational cost as the original base layer but includes all the learned improvements from each stage of the chain.
For Beginners: After training all stages of the chain, call this to create a final optimized layer. The result is a regular layer (no LoRA overhead) that performs as well as the full chain. Perfect for deployment when you want maximum speed with all the learned adaptations.
Exceptions
- InvalidOperationException
Thrown when the base layer type is not DenseLayer or FullyConnectedLayer.
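Example (a deployment sketch; input is a placeholder, and the merged layer is assumed to expose Forward through ILayer<T> as the signatures above suggest):
// Collapse the whole chain into a single plain layer for inference
ILayer<double> deployed = chain.MergeToOriginalLayer();
// The merged layer has no per-adapter overhead; its output should match the
// unmerged chain's Forward(), up to floating-point rounding
Tensor<double> merged = deployed.Forward(input);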
ResetState()
Resets the internal state of the base layer and all adapters in the chain.
public override void ResetState()
SetActiveAdapterIndex(int)
Sets which adapter in the chain is currently active for training.
public void SetActiveAdapterIndex(int index)
Parameters
index (int): The 0-based index of the adapter to activate.
Remarks
Only the active adapter receives gradient updates during training. Other adapters are either frozen (already trained earlier in the chain) or inactive (waiting to be trained).
For Beginners: This is like choosing which stage of learning you're currently working on. Set it to 0 for the first stage, 1 for the second, and so on. Only that stage's adapter is trained; earlier stages stay frozen and later stages wait their turn.
Exceptions
- ArgumentOutOfRangeException
Thrown when index is out of range.
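Example (a sketch generalizing the usage example above; TrainModel and stageData are placeholders for your training loop and per-stage datasets):
// Train and freeze each stage of the chain in order
for (int stage = 0; stage < chain.ChainLength; stage++)
{
    chain.SetActiveAdapterIndex(stage);
    TrainModel(chain, stageData[stage]);
    chain.FreezeActiveAdapter();
}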
SetParameters(Vector<T>)
Sets the layer parameters from a vector.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): Vector containing parameters.
Exceptions
- ArgumentException
Thrown when parameter count doesn't match.
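Example (a checkpoint/restore sketch; keep the chain's frozen state unchanged between the two calls so the parameter count still matches):
// Snapshot the current trainable parameters, experiment, then roll back
Vector<double> snapshot = chain.GetParameters();
// ... further training or experimentation ...
chain.SetParameters(snapshot);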
UnfreezeAdapter(int)
Unfreezes a previously frozen adapter, making it trainable again.
public void UnfreezeAdapter(int index)
Parameters
index (int): The index of the adapter to unfreeze.
Remarks
This allows re-training a previously frozen adapter if needed for iterative refinement. Useful for scenarios where you want to go back and adjust an earlier stage.
Exceptions
- ArgumentOutOfRangeException
Thrown when index is out of range.
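Example (a sketch; TrainModel and taskAData are the same placeholders used in the usage example above):
// Reopen stage 0 for additional training, then lock it in again
chain.UnfreezeAdapter(0);
chain.SetActiveAdapterIndex(0);
TrainModel(chain, taskAData);
chain.FreezeActiveAdapter();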
UpdateParameters(T)
Updates parameters using the specified learning rate.
public override void UpdateParameters(T learningRate)
Parameters
learningRate (T): The learning rate for parameter updates.
Remarks
Only the active unfrozen adapter receives updates. Frozen adapters and the base layer (if frozen) do not receive parameter updates.