Class FedAvgAggregationStrategy<T>
- Namespace
- AiDotNet.FederatedLearning.Aggregators
- Assembly
- AiDotNet.dll
Implements the Federated Averaging (FedAvg) aggregation strategy.
public class FedAvgAggregationStrategy<T> : ParameterDictionaryAggregationStrategyBase<T>, IAggregationStrategy<Dictionary<string, T[]>>
Type Parameters
- T: The numeric type for model parameters (e.g., double, float).
- Inheritance
- object → ParameterDictionaryAggregationStrategyBase<T> → FedAvgAggregationStrategy<T>
- Implements
- IAggregationStrategy<Dictionary<string, T[]>>
Remarks
FedAvg is the foundational aggregation algorithm for federated learning, proposed by McMahan et al. in 2017. It performs a weighted average of client model updates based on the number of training samples each client has.
For Beginners: FedAvg is like calculating a weighted class average where students who solved more practice problems have more influence on the final answer.
How FedAvg works:
- Each client trains on their local data and computes model updates
- Clients send their updated model weights to the server
- Server computes weighted average: weight = (client_samples / total_samples)
- New global model = Σ(weight_i × client_model_i)
For example, with 3 hospitals:
- Hospital A: 1000 patients, model accuracy 90%
- Hospital B: 500 patients, model accuracy 88%
- Hospital C: 1500 patients, model accuracy 92%
Total patients: 3000
- Hospital A weight: 1000/3000 ≈ 0.333
- Hospital B weight: 500/3000 ≈ 0.167
- Hospital C weight: 1500/3000 = 0.500
For each model parameter: global_param = 0.333 × A_param + 0.167 × B_param + 0.500 × C_param
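The weight calculation above can be sketched as follows. This is an illustrative Python sketch of the weighting arithmetic, not the library's C# API; the sample counts come from the hospital example, while the single parameter value per client is a hypothetical stand-in for real model weights.

```python
# Illustrative sketch of FedAvg sample-count weighting (not the AiDotNet C# API).
samples = {"A": 1000, "B": 500, "C": 1500}   # sample counts from the example above
params = {"A": 0.9, "B": 0.7, "C": 0.5}      # hypothetical single-parameter models

total = sum(samples.values())                 # 3000
weights = {k: n / total for k, n in samples.items()}

# global_param = sum_k weight_k * param_k
global_param = sum(weights[k] * params[k] for k in params)
print(weights, global_param)
```

Note that the weights always sum to 1, so the aggregated parameter stays within the range of the client parameters.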
Benefits:
- Simple and efficient
- Well-studied theoretically
- Works well when clients have similar data distributions (IID data)
Limitations:
- Assumes clients are equally reliable
- Can struggle with non-IID data (different distributions across clients)
- No built-in handling for stragglers (slow clients)
Reference: McMahan, H. B., et al. (2017). "Communication-Efficient Learning of Deep Networks from Decentralized Data." AISTATS 2017.
Methods
Aggregate(Dictionary<int, Dictionary<string, T[]>>, Dictionary<int, double>)
Aggregates client models using weighted averaging based on the number of samples.
public override Dictionary<string, T[]> Aggregate(Dictionary<int, Dictionary<string, T[]>> clientModels, Dictionary<int, double> clientWeights)
Parameters
- clientModels (Dictionary<int, Dictionary<string, T[]>>): Dictionary mapping client IDs to their model parameters.
- clientWeights (Dictionary<int, double>): Dictionary mapping client IDs to their sample counts (weights).
Returns
- Dictionary<string, T[]>
The aggregated global model parameters.
Remarks
This method implements the core FedAvg algorithm:
Mathematical formulation: w_global = Σ(n_k / n_total) × w_k
where:
- w_global: global model weights
- w_k: client k's model weights
- n_k: number of samples at client k
- n_total: total samples across all clients
For Beginners: This combines all client models into one by taking a weighted average, where clients with more data have more influence.
Step-by-step process:
- Calculate total samples across all clients
- For each client, compute weight = client_samples / total_samples
- For each model parameter, compute weighted sum
- Return the aggregated model
For example, if we have 2 clients with a simple model (one parameter):
- Client 1: 300 samples, parameter value = 0.8
- Client 2: 700 samples, parameter value = 0.6
Total samples: 1000
- Client 1 weight: 300/1000 = 0.3
- Client 2 weight: 700/1000 = 0.7
- Aggregated parameter: 0.3 × 0.8 + 0.7 × 0.6 = 0.24 + 0.42 = 0.66
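The step-by-step process above can be sketched end to end. This is an illustrative Python version of the weighted-average aggregation over parameter dictionaries; the actual method is implemented in C# inside the library, and the two-client data mirrors the worked example.

```python
# Illustrative Python sketch of FedAvg aggregation (the library implements
# this in C# as FedAvgAggregationStrategy<T>.Aggregate).
def fedavg_aggregate(client_models, client_weights):
    """Weighted-average each parameter array by the client's sample share."""
    total = sum(client_weights.values())
    # Assumes every client reports the same parameter names and array lengths.
    first = next(iter(client_models.values()))
    global_model = {}
    for name, values in first.items():
        global_model[name] = [
            sum((client_weights[cid] / total) * client_models[cid][name][i]
                for cid in client_models)
            for i in range(len(values))
        ]
    return global_model

# Two clients with one parameter each, matching the worked example above.
models = {1: {"w": [0.8]}, 2: {"w": [0.6]}}
counts = {1: 300.0, 2: 700.0}
print(fedavg_aggregate(models, counts))  # "w" aggregates to roughly 0.66
```

The inner sum is exactly the formula w_global = Σ(n_k / n_total) × w_k, applied independently to each element of each named parameter array.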
GetStrategyName()
Gets the name of the aggregation strategy.
public override string GetStrategyName()
Returns
- string
The string "FedAvg".