
Class FedAvgAggregationStrategy<T>

Namespace
AiDotNet.FederatedLearning.Aggregators
Assembly
AiDotNet.dll

Implements the Federated Averaging (FedAvg) aggregation strategy.

public class FedAvgAggregationStrategy<T> : ParameterDictionaryAggregationStrategyBase<T>, IAggregationStrategy<Dictionary<string, T[]>>

Type Parameters

T

The numeric type for model parameters (e.g., double, float).

Inheritance
ParameterDictionaryAggregationStrategyBase<T> → FedAvgAggregationStrategy<T>
Implements
IAggregationStrategy<Dictionary<string, T[]>>

Remarks

FedAvg is the foundational aggregation algorithm for federated learning, proposed by McMahan et al. (2017). It computes a weighted average of the client models, where each client's weight is proportional to the number of training samples it holds.

For Beginners: FedAvg is like calculating a weighted class average where students who solved more practice problems have more influence on the final answer.

How FedAvg works:

  1. Each client trains on their local data and computes model updates
  2. Clients send their updated model weights to the server
  3. Server computes weighted average: weight = (client_samples / total_samples)
  4. New global model = Σ(weight_i × client_model_i)
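The four steps above can be sketched end to end for a toy model with a single scalar parameter. This is a minimal illustration under stated assumptions, not the AiDotNet implementation: `LocalTrain` and `FedAvgRound` are hypothetical helpers, and the local training (a few gradient-descent steps on a squared-error loss) is a stand-in for whatever training a real client would run.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical local training (step 1): starting from the global parameter,
// run a few gradient-descent steps on 0.5 * mean((w - x)^2) over the local data.
static double LocalTrain(double startW, double[] localData, int steps, double learningRate)
{
    double w = startW;
    for (int s = 0; s < steps; s++)
    {
        // Gradient of 0.5 * mean((w - x)^2) with respect to w is mean(w - x).
        double grad = localData.Average(x => w - x);
        w -= learningRate * grad;
    }
    return w;
}

// Hypothetical round: each client trains and returns its parameter (step 2);
// the server weights each result by client_samples / total_samples (step 3)
// and sums the weighted results into the new global parameter (step 4).
static double FedAvgRound(double globalW, List<double[]> clientData, int localSteps, double learningRate)
{
    double totalSamples = clientData.Sum(d => d.Length);
    double newGlobal = 0.0;
    foreach (var data in clientData)
    {
        double clientW = LocalTrain(globalW, data, localSteps, learningRate);
        double weight = data.Length / totalSamples;
        newGlobal += weight * clientW;
    }
    return newGlobal;
}

// Toy data: client 0 holds three samples, client 1 holds one.
var clients = new List<double[]>
{
    new[] { 1.0, 1.0, 1.0 },
    new[] { 5.0 },
};

double model = 0.0;
for (int round = 0; round < 50; round++)
    model = FedAvgRound(model, clients, localSteps: 5, learningRate: 0.5);

Console.WriteLine($"global parameter after 50 rounds: {model:F4}");
```

For these two clients the global parameter settles at 2.0, the mean of all four samples, which is what training on the pooled data would give for this quadratic loss. On less symmetric problems, several local steps on non-IID data can pull the result away from the pooled optimum, which is the non-IID limitation listed below.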

For example, with 3 hospitals:

  • Hospital A: 1000 patients, model accuracy 90%
  • Hospital B: 500 patients, model accuracy 88%
  • Hospital C: 1500 patients, model accuracy 92%

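Total patients: 3000
Hospital A weight: 1000/3000 ≈ 0.333
Hospital B weight: 500/3000 ≈ 0.167
Hospital C weight: 1500/3000 = 0.500

For each model parameter: global_param ≈ 0.333 × A_param + 0.167 × B_param + 0.500 × C_param

The accuracy figures play no part in this calculation: FedAvg weights each hospital only by its number of patients (training samples).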
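The weight calculation for the hospital example can be reproduced directly. The patient counts come from the example; the per-hospital parameter values are hypothetical, chosen only to show the weighted sum, and the accuracy figures are deliberately left out because FedAvg does not use them.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Patient (sample) counts from the hospital example.
var patients = new Dictionary<string, int>
{
    ["Hospital A"] = 1000,
    ["Hospital B"] = 500,
    ["Hospital C"] = 1500,
};

int total = patients.Values.Sum();

// FedAvg weight for each hospital: its share of all training samples.
var weights = patients.ToDictionary(p => p.Key, p => (double)p.Value / total);

foreach (var (name, weight) in weights)
    Console.WriteLine($"{name}: {patients[name]}/{total} = {weight:F3}");

// Hypothetical values of one model parameter at each hospital (illustrative only).
var paramValues = new Dictionary<string, double>
{
    ["Hospital A"] = 0.9,
    ["Hospital B"] = 0.3,
    ["Hospital C"] = 0.6,
};

// global_param = sum over hospitals of weight × local parameter value.
double globalParam = weights.Sum(w => w.Value * paramValues[w.Key]);
Console.WriteLine($"global_param = {globalParam:F3}");
```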

Benefits:

  • Simple and efficient
  • Well-studied theoretically
  • Works well when clients have similar data distributions (IID data)

Limitations:

  • Assumes all clients are reliable: faulty or malicious updates are averaged in like any other and can skew the global model
  • Can struggle with non-IID data (different distributions across clients)
  • No built-in handling for stragglers (slow clients)

Reference: McMahan, H. B., et al. (2017). "Communication-Efficient Learning of Deep Networks from Decentralized Data." AISTATS 2017.

Methods

Aggregate(Dictionary<int, Dictionary<string, T[]>>, Dictionary<int, double>)

Aggregates client models using weighted averaging based on the number of samples.

public override Dictionary<string, T[]> Aggregate(Dictionary<int, Dictionary<string, T[]>> clientModels, Dictionary<int, double> clientWeights)

Parameters

clientModels Dictionary<int, Dictionary<string, T[]>>

Dictionary mapping client IDs to their model parameters.

clientWeights Dictionary<int, double>

Dictionary mapping client IDs to their sample counts (weights).
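Because each weight is divided by the total, `clientWeights` may hold raw sample counts or pre-normalized fractions with the same result, assuming the method normalizes exactly as the Remarks describe. The hypothetical `Normalize` helper below sketches that normalization step on its own; it is not part of the library.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: turns clientWeights values into FedAvg coefficients
// n_k / n_total by dividing each value by the sum of all values.
static Dictionary<int, double> Normalize(Dictionary<int, double> clientWeights)
{
    double total = clientWeights.Values.Sum();
    if (total <= 0)
        throw new ArgumentException("Client weights must sum to a positive value.");
    return clientWeights.ToDictionary(kv => kv.Key, kv => kv.Value / total);
}

// Raw sample counts and pre-normalized fractions describe the same split.
var fromCounts = Normalize(new Dictionary<int, double> { [1] = 300, [2] = 700 });
var fromFractions = Normalize(new Dictionary<int, double> { [1] = 0.3, [2] = 0.7 });

foreach (var id in fromCounts.Keys)
    Console.WriteLine($"client {id}: {fromCounts[id]:F2} vs {fromFractions[id]:F2}");
```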

Returns

Dictionary<string, T[]>

The aggregated global model parameters.

Remarks

This method implements the core FedAvg algorithm:

Mathematical formulation: w_global = Σ(n_k / n_total) × w_k

where:

  • w_global: global model weights
  • w_k: client k's model weights
  • n_k: number of samples at client k
  • n_total: total samples across all clients

For Beginners: This combines all client models into one by taking a weighted average, where clients with more data have more influence.

Step-by-step process:

  1. Calculate total samples across all clients
  2. For each client, compute weight = client_samples / total_samples
  3. For each named parameter array, compute the element-wise weighted sum of the client values
  4. Return the aggregated model
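The step-by-step process can be sketched as a standalone function that mirrors the shape of `Aggregate`, with `T` fixed to `double` for brevity. `FedAvgAggregate` is a hypothetical name, the parameter names in the usage (`dense.weight`, `dense.bias`) and their values are illustrative, and the sketch assumes every client sends the same parameter names and array lengths. It is not the library's own code.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of the FedAvg aggregation step with T = double.
// clientModels: client ID -> (parameter name -> parameter values)
// clientWeights: client ID -> sample count n_k
static Dictionary<string, double[]> FedAvgAggregate(
    Dictionary<int, Dictionary<string, double[]>> clientModels,
    Dictionary<int, double> clientWeights)
{
    if (clientModels.Count == 0)
        throw new ArgumentException("At least one client model is required.");

    // Step 1: total samples across the participating clients.
    double totalSamples = clientModels.Keys.Sum(id => clientWeights[id]);
    if (totalSamples <= 0)
        throw new ArgumentException("Total client weight must be positive.");

    // Zero-filled arrays shaped like the first client's model (assumes all clients match it).
    var reference = clientModels.Values.First();
    var aggregated = reference.ToDictionary(p => p.Key, p => new double[p.Value.Length]);

    foreach (var (clientId, model) in clientModels)
    {
        // Step 2: this client's coefficient n_k / n_total.
        double weight = clientWeights[clientId] / totalSamples;

        // Step 3: add the weighted values, element by element, for each named array.
        foreach (var (name, values) in model)
        {
            var target = aggregated[name];
            if (values.Length != target.Length)
                throw new ArgumentException($"Parameter '{name}' has a mismatched length for client {clientId}.");
            for (int i = 0; i < values.Length; i++)
                target[i] += weight * values[i];
        }
    }

    // Step 4: the weighted sums form the aggregated global model.
    return aggregated;
}

// Illustrative input: three clients with 100, 300 and 600 samples.
var models = new Dictionary<int, Dictionary<string, double[]>>
{
    [1] = new Dictionary<string, double[]> { ["dense.weight"] = new[] { 1.0, 2.0 }, ["dense.bias"] = new[] { 0.0 } },
    [2] = new Dictionary<string, double[]> { ["dense.weight"] = new[] { 2.0, 0.0 }, ["dense.bias"] = new[] { 1.0 } },
    [3] = new Dictionary<string, double[]> { ["dense.weight"] = new[] { 0.5, 1.0 }, ["dense.bias"] = new[] { 0.5 } },
};
var sampleCounts = new Dictionary<int, double> { [1] = 100, [2] = 300, [3] = 600 };

var globalModel = FedAvgAggregate(models, sampleCounts);
foreach (var (paramName, paramValues) in globalModel)
    Console.WriteLine($"{paramName}: [{string.Join(", ", paramValues.Select(v => v.ToString("F2")))}]");
```

With weights 0.1, 0.3 and 0.6, `dense.weight` aggregates to [1.00, 0.80] and `dense.bias` to [0.60]. Sizing the output from the first client and rejecting mismatched lengths keeps a malformed client model from silently corrupting the average.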

For example, if we have 2 clients with a simple model (one parameter):

  • Client 1: 300 samples, parameter value = 0.8
  • Client 2: 700 samples, parameter value = 0.6

Total samples: 1000
Client 1 weight: 300/1000 = 0.3
Client 2 weight: 700/1000 = 0.7
Aggregated parameter: 0.3 × 0.8 + 0.7 × 0.6 = 0.24 + 0.42 = 0.66
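The two-client arithmetic can be checked in a few lines, using only the sample counts and parameter values from the example.

```csharp
using System;
using System.Linq;

// Sample counts and the single parameter value for clients 1 and 2.
double[] samples = { 300, 700 };
double[] paramValue = { 0.8, 0.6 };

double total = samples.Sum();                                  // 1000
double[] weights = samples.Select(n => n / total).ToArray();   // 0.3, 0.7

// Weighted sum: 0.3 × 0.8 + 0.7 × 0.6
double aggregated = weights.Zip(paramValue, (w, p) => w * p).Sum();

Console.WriteLine($"aggregated parameter = {aggregated:F2}");
```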

GetStrategyName()

Gets the name of the aggregation strategy.

public override string GetStrategyName()

Returns

string

The string "FedAvg".