Class FedBNAggregationStrategy<T>

Namespace
AiDotNet.FederatedLearning.Aggregators
Assembly
AiDotNet.dll

Implements the Federated Batch Normalization (FedBN) aggregation strategy.

public class FedBNAggregationStrategy<T> : ParameterDictionaryAggregationStrategyBase<T>, IAggregationStrategy<Dictionary<string, T[]>>

Type Parameters

T

The numeric type for model parameters (e.g., double, float).

Inheritance
object → ParameterDictionaryAggregationStrategyBase<T> → FedBNAggregationStrategy<T>

Implements
IAggregationStrategy<Dictionary<string, T[]>>

Remarks

FedBN is a specialized aggregation strategy that handles batch normalization layers differently from other layers in neural networks. Proposed by Li et al. in 2021, it addresses the challenge of non-IID data by keeping batch normalization parameters local.

For Beginners: FedBN recognizes that some parts of a neural network should remain personalized to each client rather than being averaged globally.

The key insight:

  • Batch Normalization (BN) layers learn statistics specific to each client's data
  • Averaging BN parameters across clients with different data distributions hurts performance
  • Solution: Keep BN layers local, only aggregate other layers (Conv, FC, etc.)

How FedBN works:

  1. During aggregation, identify batch normalization layers
  2. Aggregate only non-BN layers using weighted averaging
  3. Keep each client's BN layers unchanged (personalized)
  4. Send back global model with client-specific BN layers
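
The sketch below illustrates this flow in isolation. It is a simplified stand-in for the library's Aggregate method (documented below), not its actual code; matching BN layers by case-insensitive substring is an assumption here.

using System;
using System.Collections.Generic;
using System.Linq;

// Minimal sketch of FedBN's selective rule (illustrative, not the library's implementation).
static Dictionary<string, double[]> FedBnAggregateSketch(
    Dictionary<int, Dictionary<string, double[]>> clientModels,
    Dictionary<int, double> clientWeights,
    HashSet<string> bnPatterns)
{
    double totalWeight = clientWeights.Values.Sum();
    int firstClient = clientModels.Keys.First();
    var globalModel = new Dictionary<string, double[]>();

    foreach (var (layerName, firstValues) in clientModels[firstClient])
    {
        if (bnPatterns.Any(p => layerName.Contains(p, StringComparison.OrdinalIgnoreCase)))
        {
            // BN layer: not averaged; each client retains its own local copy.
            globalModel[layerName] = (double[])firstValues.Clone();
            continue;
        }

        // Non-BN layer: weighted average across clients, as in FedAvg.
        var averaged = new double[firstValues.Length];
        foreach (var (clientId, model) in clientModels)
        {
            double weight = clientWeights[clientId] / totalWeight;
            for (int i = 0; i < averaged.Length; i++)
                averaged[i] += weight * model[layerName][i];
        }
        globalModel[layerName] = averaged;
    }

    return globalModel;
}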

For example, in a CNN with layers:

  • Conv1 (filters) → BN1 (normalization) → ReLU → Conv2 → BN2 → FC (classification)

FedBN aggregates:

  • ✓ Conv1 filters: Averaged across clients
  • ✗ BN1 params: Kept local to each client
  • ✓ Conv2 filters: Averaged across clients
  • ✗ BN2 params: Kept local to each client
  • ✓ FC weights: Averaged across clients

Why this matters:

  • Different clients may have different data ranges and distributions
  • Hospital A images: brightness range [0, 100]
  • Hospital B images: brightness range [50, 200]
  • Each client needs different normalization parameters
  • Shared feature extractors (Conv layers) plus personalized normalization perform better

When to use FedBN:

  • Training deep neural networks (especially CNNs)
  • Non-IID data with distribution shift across clients
  • The architecture uses batch normalization or layer normalization
  • You want to improve accuracy without major changes to the training procedure

Benefits:

  • Significantly improves accuracy on non-IID data
  • Simple modification to FedAvg
  • No additional communication cost
  • Each client keeps personalized normalization

Limitations:

  • Only helps when using batch normalization
  • Doesn't address other heterogeneity challenges
  • Requires identifying BN layers in model structure

Reference: Li, X., Jiang, M., Zhang, X., Kamp, M., & Dou, Q. (2021). "FedBN: Federated Learning on Non-IID Features via Local Batch Normalization." ICLR 2021.

Constructors

FedBNAggregationStrategy(HashSet<string>?)

Initializes a new instance of the FedBNAggregationStrategy<T> class.

public FedBNAggregationStrategy(HashSet<string>? batchNormLayerPatterns = null)

Parameters

batchNormLayerPatterns HashSet<string>

Patterns to identify batch normalization layers. If null, uses default patterns.

Remarks

For Beginners: Creates a FedBN aggregator that knows how to identify batch normalization layers in your model.

Common BN layer naming patterns:

  • "bn", "batchnorm", "batch_norm": Explicit BN layers
  • "gamma", "beta": BN trainable parameters
  • "running_mean", "running_var": BN statistics

The strategy will exclude these from aggregation, keeping them client-specific.
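
For example, construction might look like this (the custom pattern strings below are illustrative, not the library's defaults):

using System.Collections.Generic;
using AiDotNet.FederatedLearning.Aggregators;

// Use the default BN patterns.
var fedBn = new FedBNAggregationStrategy<double>();

// Or supply your own patterns if your model names its normalization layers differently.
var customPatterns = new HashSet<string> { "bn", "norm", "running_mean", "running_var" };
var customFedBn = new FedBNAggregationStrategy<double>(customPatterns);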

Methods

Aggregate(Dictionary<int, Dictionary<string, T[]>>, Dictionary<int, double>)

Aggregates client models while keeping batch normalization layers local.

public override Dictionary<string, T[]> Aggregate(Dictionary<int, Dictionary<string, T[]>> clientModels, Dictionary<int, double> clientWeights)

Parameters

clientModels Dictionary<int, Dictionary<string, T[]>>

Dictionary mapping client IDs to their model parameters.

clientWeights Dictionary<int, double>

Dictionary mapping client IDs to their sample counts (weights).

Returns

Dictionary<string, T[]>

The aggregated global model parameters with BN layers excluded from aggregation.

Remarks

This method implements selective aggregation: BN layers are passed through unchanged, while all other layers are combined by weighted averaging.

For Beginners: Think of this as a smart averaging that knows some parameters should stay personal (like BN layers) while others should be shared (like Conv layers).

Step-by-step process:

  1. For each layer in the model:
    • Check whether it is a batch normalization layer (by name matching)
    • If BN: keep the first client's values (any client's would do, since these parameters stay local in practice)
    • If not BN: compute the weighted average across all clients
  2. Return the aggregated model

Mathematical formulation:

For non-BN layers: w_global[layer] = Σ_k (n_k / n_total) × w_k[layer]

For BN layers: w_global[layer] = w_client[layer] (local values are kept)

For example, with 3 clients and a model with:

  • "conv1": [0.5, 0.6, 0.7] at clients → Average these
  • "bn1_gamma": [1.0, 1.2, 0.9] at clients → Keep local (don't average)
  • "conv2": [0.3, 0.4, 0.5] at clients → Average these
  • "bn2_beta": [0.1, 0.2, 0.15] at clients → Keep local (don't average)

Note: In practice, each client maintains its own BN parameters. The "global" model returned includes BN parameters that each client replaces with its local version upon receiving the update.
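
A sketch of a call, assuming double parameters and three equally weighted clients (all values illustrative):

using System.Collections.Generic;
using AiDotNet.FederatedLearning.Aggregators;

var strategy = new FedBNAggregationStrategy<double>();

var clientModels = new Dictionary<int, Dictionary<string, double[]>>
{
    [0] = new() { ["conv1"] = new[] { 0.5 }, ["bn1_gamma"] = new[] { 1.0 } },
    [1] = new() { ["conv1"] = new[] { 0.6 }, ["bn1_gamma"] = new[] { 1.2 } },
    [2] = new() { ["conv1"] = new[] { 0.7 }, ["bn1_gamma"] = new[] { 0.9 } },
};

// Weights are sample counts; here every client contributed 100 samples.
var clientWeights = new Dictionary<int, double> { [0] = 100, [1] = 100, [2] = 100 };

var globalModel = strategy.Aggregate(clientModels, clientWeights);
// globalModel["conv1"][0] == 0.6 (weighted average); "bn1_gamma" is not averaged.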

GetBatchNormPatterns()

Gets the batch normalization layer patterns used for identification.

public IReadOnlyCollection<string> GetBatchNormPatterns()

Returns

IReadOnlyCollection<string>

The read-only collection of patterns used to identify batch normalization layers.

Remarks

For Beginners: Returns the list of patterns used to recognize which layers are batch normalization layers.
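
For instance, you could log the active patterns to check that they match your model's layer names (the example output in the comment is illustrative):

using System;
using AiDotNet.FederatedLearning.Aggregators;

var strategy = new FedBNAggregationStrategy<double>();
foreach (string pattern in strategy.GetBatchNormPatterns())
{
    Console.WriteLine(pattern); // e.g., "bn", "batchnorm", ...
}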

GetStrategyName()

Gets the name of the aggregation strategy.

public override string GetStrategyName()

Returns

string

The string "FedBN".