Class ScaffoldHeterogeneityCorrection<T>
- Namespace: AiDotNet.FederatedLearning.Heterogeneity
- Assembly: AiDotNet.dll
SCAFFOLD-style heterogeneity correction using control variates.
public sealed class ScaffoldHeterogeneityCorrection<T> : FederatedHeterogeneityCorrectionBase<T>, IFederatedHeterogeneityCorrection<T>
Type Parameters
T: Numeric type.
- Inheritance: FederatedHeterogeneityCorrectionBase<T> → ScaffoldHeterogeneityCorrection<T>
- Implements: IFederatedHeterogeneityCorrection<T>
Remarks
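For Beginners: When clients train on differently distributed (non-IID) data, their local updates tend to pull away from the global objective. This is called "client drift". SCAFFOLD reduces client drift by tracking "control variates" that estimate how each client's local training direction differs from the global direction, and then correcting each client's update by that difference.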
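For reference, the update rules from the original SCAFFOLD paper (Karimireddy et al., 2020), using its "option II" control-variate update, are sketched below. Here x is the global model, y_i is client i's local model, c and c_i are the server and client control variates, g_i is client i's stochastic gradient, η_l and η_g are the client and server step sizes, K is the number of local steps, S is the set of participating clients, and N is the total number of clients. How this class maps its inputs onto these symbols is not stated on this page; it seems likely that η_l corresponds to clientLearningRate and K to localEpochs, but treat that as an assumption.

```latex
\begin{aligned}
y_i &\leftarrow y_i - \eta_l \left( g_i(y_i) - c_i + c \right)
  && \text{(repeated for each of the } K \text{ local steps)} \\
c_i^{+} &= c_i - c + \frac{1}{K \eta_l} \left( x - y_i \right)
  && \text{(client control variate, option II)} \\
x &\leftarrow x + \eta_g \, \frac{1}{|S|} \sum_{i \in S} \left( y_i - x \right)
  && \text{(server model)} \\
c &\leftarrow c + \frac{1}{N} \sum_{i \in S} \left( c_i^{+} - c_i \right)
  && \text{(server control variate)}
\end{aligned}
```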
Constructors
ScaffoldHeterogeneityCorrection(double)
public ScaffoldHeterogeneityCorrection(double clientLearningRate = 1)
Parameters
clientLearningRate (double): Learning rate assumed for client-side local training. Defaults to 1.
Methods
Correct(int, int, Vector<T>, Vector<T>, int)
Returns corrected client parameters to be used for aggregation.
public override Vector<T> Correct(int clientId, int roundNumber, Vector<T> globalParameters, Vector<T> localParameters, int localEpochs)
Parameters
clientId (int): Client identifier.
roundNumber (int): Round number (0-indexed).
globalParameters (Vector<T>): Global parameter vector at the start of the round.
localParameters (Vector<T>): Client-trained parameter vector.
localEpochs (int): Local epochs used for training (a proxy for the number of local steps in simulation).
Returns
- Vector<T>: Corrected parameters.
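The following is a minimal, self-contained sketch of one way a post-hoc SCAFFOLD correction could be computed in a simulation where local training ran without the corrected gradient. It is an assumption, not the AiDotNet implementation: it uses plain double[] instead of Vector<T>, and the class ScaffoldSketch and its EndRound method are hypothetical names. It treats localEpochs as the number of local steps K. If the correction (c - c_i) had been applied at each of K steps with c and c_i held fixed, the final local model would shift by -K η_l (c - c_i); the sketch subtracts that shift (ignoring how the correction would have changed the points where gradients are evaluated) and then updates the client's control variate with option II.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of a SCAFFOLD-style post-hoc correction.
// Not the AiDotNet implementation; names and structure are assumptions.
public sealed class ScaffoldSketch
{
    private readonly double _clientLearningRate;                      // eta_l
    private readonly Dictionary<int, double[]> _clientControls =
        new Dictionary<int, double[]>();                              // c_i per client
    private double[] _serverControl = Array.Empty<double>();          // c

    public ScaffoldSketch(double clientLearningRate = 1.0)
    {
        if (clientLearningRate <= 0)
            throw new ArgumentOutOfRangeException(nameof(clientLearningRate));
        _clientLearningRate = clientLearningRate;
    }

    public double[] Correct(int clientId, double[] global, double[] local, int localEpochs)
    {
        if (global.Length != local.Length)
            throw new ArgumentException("Global and local vectors must have the same length.");
        int n = global.Length;
        if (_serverControl.Length == 0) _serverControl = new double[n]; // c starts at zero
        if (_serverControl.Length != n)
            throw new ArgumentException("Vector length changed between calls.");

        double k = Math.Max(1, localEpochs); // localEpochs stands in for K local steps
        if (!_clientControls.TryGetValue(clientId, out var ci))
            ci = new double[n];              // c_i starts at zero for a new client

        var corrected = new double[n];
        var newCi = new double[n];
        for (int j = 0; j < n; j++)
        {
            // Shift that K corrected steps would have produced: -K * eta_l * (c - c_i).
            corrected[j] = local[j] - k * _clientLearningRate * (_serverControl[j] - ci[j]);
            // Option II: c_i+ = c_i - c + (x - y_i) / (K * eta_l).
            newCi[j] = ci[j] - _serverControl[j]
                       + (global[j] - corrected[j]) / (k * _clientLearningRate);
        }
        _clientControls[clientId] = newCi;
        return corrected;
    }

    // Call once after all clients in a round: sets c to the mean of the stored c_i.
    // With a fixed client population and zero initial control variates, this equals
    // the paper's update c + (1/N) * sum(c_i+ - c_i).
    public void EndRound()
    {
        if (_clientControls.Count == 0) return;
        var mean = new double[_serverControl.Length];
        foreach (var ci in _clientControls.Values)
            for (int j = 0; j < mean.Length; j++)
                mean[j] += ci[j] / _clientControls.Count;
        _serverControl = mean;
    }
}
```

Keeping the server-control update in a separate EndRound call means every client in a round is corrected against the same c, as in the paper. Because η_l divides the option II estimate, a clientLearningRate that does not match the learning rate actually used for local training scales the control variates by the wrong factor.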
GetCorrectionName()
Gets the name of the correction method.
public override string GetCorrectionName()