
Class FedProxAggregationStrategy<T>

Namespace: AiDotNet.FederatedLearning.Aggregators
Assembly: AiDotNet.dll

Implements the Federated Proximal (FedProx) aggregation strategy.

public class FedProxAggregationStrategy<T> : ParameterDictionaryAggregationStrategyBase<T>, IAggregationStrategy<Dictionary<string, T[]>>

Type Parameters

T

The numeric type for model parameters (e.g., double, float).

Inheritance
object → ParameterDictionaryAggregationStrategyBase<T> → FedProxAggregationStrategy<T>

Implements
IAggregationStrategy<Dictionary<string, T[]>>

Remarks

FedProx is an extension of FedAvg that handles system and statistical heterogeneity in federated learning. It was proposed by Li et al. in 2020 to address challenges when clients have different computational capabilities or data distributions.

For Beginners: FedProx is like FedAvg with a "safety rope" that prevents individual clients from pulling the shared model too far in their own direction.

Key differences from FedAvg:

  1. Adds a proximal term to the local training objective
  2. Prevents client models from deviating too far from the global model
  3. Improves convergence when clients have heterogeneous data or capabilities

How FedProx works: during local training, each client minimizes Local Loss + (μ/2) × ||w - w_global||² (a code sketch of this penalty follows the example below).

where:

  • Local Loss: Standard loss on client's data
  • μ (mu): Proximal term coefficient (controls constraint strength)
  • w: Client's current model weights
  • w_global: Global model weights received from server
  • ||w - w_global||²: Squared distance between client and global model

For example, with μ = 0.01:

  • Client trains on local data
  • Proximal term penalizes large deviations from global model
  • If a client's data is very different, its model can still adapt, but only to a limited degree
  • Prevents overfitting to the local data distribution
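
To make the objective above concrete, here is a minimal sketch of the proximal penalty (μ/2) × ||w - w_global||² for a single layer's weights. The helper is illustrative only and is not part of AiDotNet; in practice this computation runs inside each client's local training loop.

// Illustrative only (not part of AiDotNet): the FedProx penalty
// (mu / 2) * ||w - wGlobal||^2 for one layer's weights.
static double ProximalPenalty(double[] w, double[] wGlobal, double mu)
{
    double squaredDistance = 0.0;
    for (int i = 0; i < w.Length; i++)
    {
        double diff = w[i] - wGlobal[i];
        squaredDistance += diff * diff;
    }
    return (mu / 2.0) * squaredDistance;
}

// Each client then minimizes: localLoss + ProximalPenalty(w, wGlobal, mu)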

When to use FedProx:

  • Non-IID data (different distributions across clients)
  • System heterogeneity (some clients much slower/faster)
  • Want more stable convergence than FedAvg
  • Straggler problem (some clients take much longer to finish local training)

Benefits:

  • Better convergence on non-IID data
  • More robust to stragglers
  • Theoretically proven convergence guarantees
  • Small computational overhead

Limitations:

  • Requires tuning μ parameter
  • Slightly slower local training per iteration
  • May converge slower if μ is too large

Reference: Li, T., et al. (2020). "Federated Optimization in Heterogeneous Networks." MLSys 2020.

Constructors

FedProxAggregationStrategy(double)

Initializes a new instance of the FedProxAggregationStrategy<T> class.

public FedProxAggregationStrategy(double mu = 0.01)

Parameters

mu double

The proximal term coefficient (typically 0.01 to 1.0).

Remarks

For Beginners: Creates a FedProx aggregator with a specified proximal term strength.

The μ (mu) parameter controls the trade-off between local adaptation and global consistency:

  • μ = 0: Equivalent to FedAvg (no constraint)
  • μ = 0.01: Mild constraint (recommended starting point)
  • μ = 0.1: Moderate constraint
  • μ = 1.0+: Strong constraint (may be too restrictive)

Recommendations:

  • Start with μ = 0.01
  • Increase if convergence is unstable
  • Decrease if convergence is too slow
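
For example, following these recommendations with double parameters:

using AiDotNet.FederatedLearning.Aggregators;

// Recommended starting point: a mild constraint.
var aggregator = new FedProxAggregationStrategy<double>(mu: 0.01);

// If convergence is unstable across rounds, try a stronger constraint,
// e.g. new FedProxAggregationStrategy<double>(mu: 0.1).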

Methods

Aggregate(Dictionary<int, Dictionary<string, T[]>>, Dictionary<int, double>)

Aggregates client models using FedProx weighted averaging.

public override Dictionary<string, T[]> Aggregate(Dictionary<int, Dictionary<string, T[]>> clientModels, Dictionary<int, double> clientWeights)

Parameters

clientModels Dictionary<int, Dictionary<string, T[]>>

Dictionary mapping client IDs to their model parameters.

clientWeights Dictionary<int, double>

Dictionary mapping client IDs to their sample counts (weights).

Returns

Dictionary<string, T[]>

The aggregated global model parameters.

Remarks

The aggregation step in FedProx is identical to FedAvg. The key difference is in the local training objective (which includes the proximal term), not in aggregation.

For Beginners: At the server side, FedProx aggregates just like FedAvg. The magic happens during client-side training where the proximal term keeps client models from straying too far.

Aggregation formula (same as FedAvg): w_global = Σ_k (n_k / n_total) × w_k, where w_k is client k's model weights, n_k is client k's sample count, and n_total is the total sample count across all participating clients.

The proximal term μ affects how w_k is computed during local training, but not how we aggregate the models here.
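
A minimal sketch of that weighted average, assuming T = double; the actual method is generic in T, and this sketch omits any validation it may perform:

using System.Collections.Generic;
using System.Linq;

static Dictionary<string, double[]> WeightedAverage(
    Dictionary<int, Dictionary<string, double[]>> clientModels,
    Dictionary<int, double> clientWeights)
{
    double totalWeight = clientWeights.Values.Sum();   // n_total
    var globalModel = new Dictionary<string, double[]>();

    foreach (var (clientId, model) in clientModels)
    {
        double fraction = clientWeights[clientId] / totalWeight;   // n_k / n_total
        foreach (var (layerName, layerWeights) in model)
        {
            if (!globalModel.TryGetValue(layerName, out var accumulated))
            {
                accumulated = new double[layerWeights.Length];
                globalModel[layerName] = accumulated;
            }
            for (int i = 0; i < layerWeights.Length; i++)
                accumulated[i] += fraction * layerWeights[i];   // Σ (n_k / n_total) × w_k
        }
    }
    return globalModel;
}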

In local training (implemented on the client, not in this class; sketched in code below):

  • Gradient = ∇Loss + μ(w - w_global)
  • The additional term pulls the client's weights toward the global model
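
A hypothetical sketch of that client-side step (again, not part of this server-side class; all names are illustrative):

// Illustrative client-side SGD step with the FedProx proximal term.
// gradLoss holds ∇Loss for this layer's weights.
static void ProximalSgdStep(double[] w, double[] wGlobal, double[] gradLoss,
    double mu, double learningRate)
{
    for (int i = 0; i < w.Length; i++)
    {
        double grad = gradLoss[i] + mu * (w[i] - wGlobal[i]);   // ∇Loss + μ(w - w_global)
        w[i] -= learningRate * grad;
    }
}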

GetMu()

Gets the proximal term coefficient μ.

public double GetMu()

Returns

double

The μ parameter value.

Remarks

For Beginners: Returns the strength of the constraint that keeps client models from deviating too far from the global model.

GetStrategyName()

Gets the name of the aggregation strategy.

public override string GetStrategyName()

Returns

string

A string indicating "FedProx" with the μ parameter value.
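
Example

Putting the pieces together, a minimal server-side usage sketch with two hypothetical clients and double parameters (the exact GetStrategyName() output format shown in the comment is an assumption):

using System;
using System.Collections.Generic;
using AiDotNet.FederatedLearning.Aggregators;

var aggregator = new FedProxAggregationStrategy<double>(mu: 0.01);

// Hypothetical client updates: client ID -> (layer name -> weights).
var clientModels = new Dictionary<int, Dictionary<string, double[]>>
{
    [0] = new() { ["layer1"] = new[] { 1.0, 2.0 } },
    [1] = new() { ["layer1"] = new[] { 3.0, 4.0 } },
};

// Weights are sample counts: client 1 contributed three times as much data.
var clientWeights = new Dictionary<int, double> { [0] = 100.0, [1] = 300.0 };

Dictionary<string, double[]> globalModel =
    aggregator.Aggregate(clientModels, clientWeights);

Console.WriteLine(string.Join(", ", globalModel["layer1"]));   // 2.5, 3.5
Console.WriteLine(aggregator.GetStrategyName());               // e.g. "FedProx (μ=0.01)"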