Class AggregationStrategyBase<TModel, T>
- Namespace: AiDotNet.FederatedLearning.Aggregators
- Assembly: AiDotNet.dll
Base class for federated aggregation strategies.
public abstract class AggregationStrategyBase<TModel, T> : FederatedLearningComponentBase<T>, IAggregationStrategy<TModel>
Type Parameters
TModel: The model or model-update representation.
T: The numeric type.
- Inheritance: FederatedLearningComponentBase<T> → AggregationStrategyBase<TModel, T>
- Implements: IAggregationStrategy<TModel>
Methods
Aggregate(Dictionary<int, TModel>, Dictionary<int, double>)
Aggregates model updates from multiple clients into a single global model update.
public abstract TModel Aggregate(Dictionary<int, TModel> clientModels, Dictionary<int, double> clientWeights)
Parameters
clientModels (Dictionary<int, TModel>): Dictionary mapping client IDs to their trained models.
clientWeights (Dictionary<int, double>): Dictionary mapping client IDs to their aggregation weights (typically based on data size).
Returns
- TModel
The aggregated global model.
Remarks
This method combines model updates from clients using the strategy's specific algorithm.
For Beginners: Aggregation is like combining multiple rough drafts of a document into one polished version that incorporates the best parts of each.
The aggregation process typically:
- Takes model updates (weight changes) from each client
- Considers the weight or importance of each client (based on data size, accuracy, etc.)
- Combines these updates using the strategy's algorithm
- Returns a single aggregated model that represents the collective improvement
For example, with weighted averaging (FedAvg):
- Client 1 (1000 samples): model update A
- Client 2 (500 samples): model update B
- Client 3 (1500 samples): model update C
- Aggregated update = (1000*A + 500*B + 1500*C) / 3000
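The weighted-averaging arithmetic above can be sketched in C# as follows. This is a minimal, standalone illustration, not the library's FedAvg implementation. It assumes each model update is a flat `double[]` of equal length and that the weights are the clients' sample counts. `FedAvgSketch` and `WeightedAverage` are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of AiDotNet: FedAvg-style weighted averaging
// under the assumption that every client update is a double[] of the same length.
public static class FedAvgSketch
{
    public static double[] WeightedAverage(
        Dictionary<int, double[]> clientModels,
        Dictionary<int, double> clientWeights)
    {
        if (clientModels.Count == 0)
            throw new ArgumentException("At least one client model is required.", nameof(clientModels));

        int length = -1;
        double[] sum = Array.Empty<double>();
        double totalWeight = 0.0;

        foreach (var (clientId, update) in clientModels)
        {
            if (!clientWeights.TryGetValue(clientId, out double weight))
                throw new ArgumentException($"No weight for client {clientId}.", nameof(clientWeights));

            if (length < 0)
            {
                // The first update fixes the expected vector length.
                length = update.Length;
                sum = new double[length];
            }
            else if (update.Length != length)
            {
                throw new ArgumentException("All updates must have the same length.", nameof(clientModels));
            }

            // Accumulate weight * update element by element (e.g. 1000*A).
            for (int i = 0; i < length; i++)
                sum[i] += weight * update[i];
            totalWeight += weight;
        }

        if (totalWeight <= 0.0)
            throw new ArgumentException("Total weight must be positive.", nameof(clientWeights));

        // Divide by the total weight (3000 in the example above).
        for (int i = 0; i < length; i++)
            sum[i] /= totalWeight;
        return sum;
    }
}
```

With the sample counts from the example (1000, 500, 1500) and updates A = [1, 2], B = [4, 8], C = [3, 0], `WeightedAverage` returns [2.5, 2.0].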
GetStrategyName()
Gets the name of the aggregation strategy.
public abstract string GetStrategyName()
Returns
- string
A string describing the aggregation strategy (e.g., "FedAvg", "FedProx", "Krum").
Remarks
For Beginners: This helps identify which aggregation method is being used, useful for logging, debugging, and comparing different strategies.
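To illustrate how a derived strategy might fill in both abstract members, the sketch below uses a hypothetical stand-in base class, `StrategySketchBase<TModel>`, in place of `AggregationStrategyBase<TModel, T>`, whose constructor and inherited members are not documented here. It assumes `TModel` is `double[]`; `FedAvgStrategySketch` is likewise a hypothetical name, not a library type.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in that mirrors only the two abstract members documented above.
public abstract class StrategySketchBase<TModel>
{
    public abstract TModel Aggregate(Dictionary<int, TModel> clientModels, Dictionary<int, double> clientWeights);
    public abstract string GetStrategyName();
}

// Hypothetical FedAvg-style strategy over flat double[] updates.
public sealed class FedAvgStrategySketch : StrategySketchBase<double[]>
{
    public override double[] Aggregate(Dictionary<int, double[]> clientModels, Dictionary<int, double> clientWeights)
    {
        // Assumes at least one client and equal-length updates.
        int length = clientModels.Values.First().Length;
        double total = clientModels.Keys.Sum(id => clientWeights[id]);
        var result = new double[length];

        foreach (var (id, update) in clientModels)
        {
            // Each client contributes in proportion to its share of the total weight.
            double share = clientWeights[id] / total;
            for (int i = 0; i < length; i++)
                result[i] += share * update[i];
        }
        return result;
    }

    // A short, stable identifier suitable for logs and strategy comparisons.
    public override string GetStrategyName() => "FedAvg";
}
```

Calling code can then tag each round in its logs, for example with `Console.WriteLine($"Aggregated round with {strategy.GetStrategyName()}")`.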
GetTotalWeightOrThrow(Dictionary<int, double>, IEnumerable<int>, string)
protected static double GetTotalWeightOrThrow(Dictionary<int, double> clientWeights, IEnumerable<int> clientIds, string paramName)
Parameters
clientWeights (Dictionary<int, double>)
clientIds (IEnumerable<int>)
paramName (string)
Returns
- double
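Judging only from its name and signature, `GetTotalWeightOrThrow` plausibly sums the weights of the listed clients and throws when that total cannot serve as a divisor. The sketch below implements that assumed behavior in a hypothetical `WeightSketch` class. The specific checks (missing weight, negative or non-finite weight, non-positive total) and the use of `ArgumentException` with `paramName` are assumptions, not the library's documented contract.

```csharp
using System;
using System.Collections.Generic;

public static class WeightSketch
{
    // Hypothetical re-implementation based only on the documented name and signature.
    public static double GetTotalWeightOrThrow(
        Dictionary<int, double> clientWeights,
        IEnumerable<int> clientIds,
        string paramName)
    {
        if (clientWeights == null) throw new ArgumentNullException(paramName);
        if (clientIds == null) throw new ArgumentNullException(nameof(clientIds));

        double total = 0.0;
        foreach (int id in clientIds)
        {
            // Assumption: every listed client must have a weight entry.
            if (!clientWeights.TryGetValue(id, out double w))
                throw new ArgumentException($"Missing weight for client {id}.", paramName);

            // Assumption: weights must be finite and non-negative.
            if (double.IsNaN(w) || double.IsInfinity(w) || w < 0.0)
                throw new ArgumentException($"Invalid weight {w} for client {id}.", paramName);

            total += w;
        }

        // Assumption: a zero total would make a weighted average undefined.
        if (total <= 0.0)
            throw new ArgumentException("Total client weight must be positive.", paramName);

        return total;
    }
}
```

With the weights from the FedAvg example (1000, 500, 1500) and client IDs 1, 2, 3, this sketch returns 3000.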