
Class SecureAggregation<T>

Namespace
AiDotNet.FederatedLearning.Privacy
Assembly
AiDotNet.dll

Implements secure aggregation for federated learning using cryptographic techniques.

public class SecureAggregation<T> : FederatedLearningComponentBase<T>, IDisposable

Type Parameters

T

The numeric type for model parameters (e.g., double, float).

Inheritance
FederatedLearningComponentBase<T> → SecureAggregation<T>
Implements
IDisposable

Remarks

Secure aggregation is a cryptographic protocol that allows a server to compute the sum of client updates without seeing individual contributions. Only the final aggregate is visible to the server.

For Beginners: Secure aggregation is like a secret ballot election where votes are counted but individual votes remain private.

How it works (simplified):

  1. Each client generates pairwise secret keys with other clients
  2. Clients mask their model updates with these secret keys
  3. Server receives masked updates: masked_update_i = update_i + Σ(secrets_ij)
  4. Secret masks cancel out when summing: Σ(masked_update_i) = Σ(update_i)
  5. Server gets the sum without seeing individual updates
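The five steps above can be sketched in a few lines of C#. This is a minimal illustration, not the AiDotNet implementation. All masks are produced in one place from a seeded System.Random, so step 1 is collapsed into a single helper. The class and method names are hypothetical. Because the values are floating-point, the masks cancel only up to rounding error.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of pairwise additive masking (not AiDotNet's implementation).
public static class PairwiseMaskingSketch
{
    // Step 1 (simplified): antisymmetric masks, masks[(i, j)] = -masks[(j, i)].
    public static Dictionary<(int, int), double[]> MakeMasks(int clients, int length, int seed)
    {
        var rng = new Random(seed); // NOT cryptographically secure; illustration only
        var masks = new Dictionary<(int, int), double[]>();
        for (int i = 0; i < clients; i++)
            for (int j = i + 1; j < clients; j++)
            {
                var s = new double[length];
                var neg = new double[length];
                for (int k = 0; k < length; k++) { s[k] = rng.NextDouble() * 2 - 1; neg[k] = -s[k]; }
                masks[(i, j)] = s;
                masks[(j, i)] = neg;
            }
        return masks;
    }

    // Steps 2-3: masked_update_i = update_i + sum over j != i of s_ij.
    public static double[] Mask(int i, double[] update, Dictionary<(int, int), double[]> masks, int clients)
    {
        var result = (double[])update.Clone();
        for (int j = 0; j < clients; j++)
        {
            if (j == i) continue;
            var s = masks[(i, j)];
            for (int k = 0; k < result.Length; k++) result[k] += s[k];
        }
        return result;
    }

    // Steps 4-5: the server adds the masked vectors; the pairwise masks cancel.
    public static double[] Sum(IEnumerable<double[]> vectors, int length)
    {
        var total = new double[length];
        foreach (var v in vectors)
            for (int k = 0; k < length; k++) total[k] += v[k];
        return total;
    }
}
```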

Example with 3 clients:

  • Client 1 shares secrets: s₁₂ with Client 2, s₁₃ with Client 3
  • Client 2 shares secrets: s₂₁ with Client 1, s₂₃ with Client 3
  • Client 3 shares secrets: s₃₁ with Client 1, s₃₂ with Client 2

Note: s₁₂ = -s₂₁, s₁₃ = -s₃₁, and s₂₃ = -s₃₂, so the secrets cancel in pairs.

Client 1 sends: update₁ + s₁₂ + s₁₃
Client 2 sends: update₂ + s₂₁ + s₂₃
Client 3 sends: update₃ + s₃₁ + s₃₂

Server computes the sum:
  (update₁ + s₁₂ + s₁₃) + (update₂ + s₂₁ + s₂₃) + (update₃ + s₃₁ + s₃₂)
  = update₁ + update₂ + update₃ + (s₁₂ + s₂₁) + (s₁₃ + s₃₁) + (s₂₃ + s₃₂)
  = update₁ + update₂ + update₃ + 0 + 0 + 0
  = Σ(updates)   ← Only this is visible to the server!
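Floating-point masks cancel only up to rounding, which is why Bonawitz et al. work with integers modulo a fixed range. The sketch below replays the three-client example with wrapping 32-bit unsigned arithmetic (mod 2³²), so the cancellation is exact. The integer encoding of the updates and the concrete values are assumptions for illustration. This does not describe how SecureAggregation<T> stores its masks.

```csharp
using System;

// Hypothetical sketch: the three-client example computed modulo 2^32.
public static class ThreeClientExample
{
    // Returns what each client sends and the server's sum, all modulo 2^32.
    public static (uint[] sent, uint sum) Run(uint u1, uint u2, uint u3, uint s12, uint s13, uint s23)
    {
        unchecked
        {
            // Antisymmetric partners: s21 = -s12 (mod 2^32), and likewise for the others.
            uint s21 = 0u - s12, s31 = 0u - s13, s32 = 0u - s23;
            uint m1 = u1 + s12 + s13;   // Client 1 sends
            uint m2 = u2 + s21 + s23;   // Client 2 sends
            uint m3 = u3 + s31 + s32;   // Client 3 sends
            // Wrapping addition makes every (s_ij + s_ji) exactly 0 mod 2^32.
            return (new[] { m1, m2, m3 }, m1 + m2 + m3);
        }
    }
}
```

With updates 10, 20, 30 the sum is exactly 60, while each individual message looks like an unrelated 32-bit number.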

This implementation derives pairwise mask seeds from per-round ephemeral ECDH shared secrets and expands them via HKDF + a deterministic PRG. Pairwise masks cancel in the aggregate as long as all selected clients participate in the round (synchronous, full-participation mode).
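The following sketches one way to derive a pairwise mask along these lines. It assumes .NET 6 or later, which provides HKDF and HMACSHA256.HashData. Several details are assumptions: the P-256 curve, the HKDF info label, and the HMAC-SHA256 counter-mode generator, which stands in for the "deterministic PRG". The library's actual key schedule is not documented here.

```csharp
using System;
using System.Buffers.Binary;
using System.Security.Cryptography;
using System.Text;

// Hypothetical sketch: ephemeral ECDH -> HKDF seed -> HMAC-SHA256 counter-mode PRG.
public static class EcdhMaskSketch
{
    public static double[] DeriveMask(ECDiffieHellman self, ECDiffieHellmanPublicKey peerPublic,
                                      int selfId, int peerId, int round, int length)
    {
        // Both parties obtain the same bytes: SHA-256 over the ECDH shared secret.
        byte[] shared = self.DeriveKeyFromHash(peerPublic, HashAlgorithmName.SHA256);

        // HKDF binds the seed to the round and to the unordered client pair.
        int lo = Math.Min(selfId, peerId), hi = Math.Max(selfId, peerId);
        byte[] info = Encoding.UTF8.GetBytes($"secagg-mask|round={round}|{lo}-{hi}");
        byte[] seed = HKDF.DeriveKey(HashAlgorithmName.SHA256, shared, 32, salt: null, info: info);

        var mask = new double[length];
        var counter = new byte[8];
        byte[] block = Array.Empty<byte>();
        double sign = selfId < peerId ? 1.0 : -1.0;   // lower ID adds, higher ID subtracts
        for (int k = 0; k < length; k++)
        {
            if (k % 4 == 0)
            {
                BinaryPrimitives.WriteInt64BigEndian(counter, k / 4);
                block = HMACSHA256.HashData(seed, counter);   // 32 bytes = 4 ulongs
            }
            ulong bits = BinaryPrimitives.ReadUInt64BigEndian(block.AsSpan((k % 4) * 8, 8));
            double unit = (bits >> 11) * (1.0 / (1UL << 53));   // uniform in [0, 1)
            mask[k] = sign * (2.0 * unit - 1.0);                 // uniform in [-1, 1)
        }
        return mask;
    }

    // Two clients with fresh (ephemeral) P-256 key pairs derive their masks for one round.
    public static (double[] client1, double[] client2) DemoPair(int length, int round)
    {
        using var k1 = ECDiffieHellman.Create(ECCurve.NamedCurves.nistP256);
        using var k2 = ECDiffieHellman.Create(ECCurve.NamedCurves.nistP256);
        return (DeriveMask(k1, k2.PublicKey, 1, 2, round, length),
                DeriveMask(k2, k1.PublicKey, 2, 1, round, length));
    }
}
```

Both clients hash the same shared secret and use the same info string, so they expand identical streams. The sign convention (lower ID adds, higher ID subtracts) is one simple way to obtain s_ji = -s_ij without exchanging the mask itself.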

Benefits:

  • Server cannot see individual client updates
  • Protects against honest-but-curious server
  • No trusted third party needed
  • Computation overhead is modest: each client derives one mask per peer, so per-client cost grows linearly with the number of clients

Limitations:

  • Requires coordination between clients
  • All (or threshold) clients must participate for masks to cancel
  • Dropout handling requires additional mechanisms
  • Communication overhead for key exchange
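The hypothetical sketch below shows why dropout matters in full-participation mode. Client 3 drops after the masks have been fixed. The server's sum of clients 1 and 2 is then off by exactly s₁₃ + s₂₃, which are the masks client 3 would have cancelled. The random values come from a seeded System.Random, purely for illustration.

```csharp
using System;

// Hypothetical sketch: effect of one client dropping out after masking.
public static class DropoutSketch
{
    // Client 3 drops; the server sums only clients 1 and 2.
    // Returns (observed error, s13 + s23) so the two can be compared.
    public static (double[] error, double[] expected) Run(int length, int seed)
    {
        var rng = new Random(seed);   // illustration only, not a secure mask source
        double[] Rand()
        {
            var v = new double[length];
            for (int k = 0; k < length; k++) v[k] = rng.NextDouble() * 2 - 1;
            return v;
        }
        double[] u1 = Rand(), u2 = Rand(), s12 = Rand(), s13 = Rand(), s23 = Rand();
        var error = new double[length];
        var expected = new double[length];
        for (int k = 0; k < length; k++)
        {
            double m1 = u1[k] + s12[k] + s13[k];   // client 1: + s12 + s13
            double m2 = u2[k] - s12[k] + s23[k];   // client 2: s21 = -s12, plus s23
            error[k] = (m1 + m2) - (u1[k] + u2[k]);
            expected[k] = s13[k] + s23[k];         // client 3's -s13 and -s23 never arrive
        }
        return (error, expected);
    }
}
```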

When to use Secure Aggregation:

  • Don't fully trust the central server
  • Regulatory requirements for data protection
  • Want cryptographic privacy guarantees
  • Willing to handle additional complexity

Can be combined with differential privacy for stronger protection:

  • Secure aggregation: Protects individual updates from server
  • Differential privacy: Protects individual data points from anyone

Reference: Bonawitz, K., et al. (2017). "Practical Secure Aggregation for Privacy-Preserving Machine Learning." CCS 2017.

Constructors

SecureAggregation(int, int?)

Initializes a new instance of the SecureAggregation<T> class.

public SecureAggregation(int parameterCount, int? randomSeed = null)

Parameters

parameterCount int

The total number of model parameters to protect.

randomSeed int?

Optional random seed for reproducibility.

Remarks

For Beginners: Sets up the secure aggregation protocol for a specific number of model parameters.

In practice, this would involve:

  • Secure key exchange between clients
  • Authenticated channels
  • Agreement on random seed for deterministic mask generation

In this in-memory implementation, pairwise masks are valid for a single round only: call GeneratePairwiseSecrets(List<int>) again before each new round.

Methods

AggregateSecurely(Dictionary<int, Dictionary<string, T[]>>, Dictionary<int, double>)

Aggregates masked updates from all clients, returning a weighted average.

public Dictionary<string, T[]> AggregateSecurely(Dictionary<int, Dictionary<string, T[]>> maskedUpdates, Dictionary<int, double> clientWeights)

Parameters

maskedUpdates Dictionary<int, Dictionary<string, T[]>>

Dictionary of client IDs to their masked updates.

clientWeights Dictionary<int, double>

Dictionary of client IDs to their aggregation weights.

Returns

Dictionary<string, T[]>

The securely aggregated model (weighted average if clients pre-weighted their updates before masking).

Remarks

For Beginners: This sums up all the masked updates. Because the secret masks cancel out, the server recovers the sum of the underlying (possibly weighted) client updates without ever seeing any individual update.

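To compute a weighted average securely:

  1. Each client calls MaskUpdate(int, Dictionary<string, T[]>, double) with its aggregation weight, so the update is weighted before the masks are added.
  2. The server passes the masked updates and the same weights to AggregateSecurely; the masks cancel in the sum, which is then normalized by the clients' weights to give the weighted average.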

If you need the raw (un-normalized) sum of updates, use AggregateSumSecurely(Dictionary<int, Dictionary<string, T[]>>).
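The weighted flow can be sketched as follows, assuming that the server divides the cancelled sum by the total of the client weights. The names are hypothetical, and the masks come from a seeded System.Random for illustration only. The weights are applied before masking, as the weighted MaskUpdate overload describes.

```csharp
using System;

// Hypothetical sketch: weight -> mask -> sum -> divide by total weight.
public static class WeightedSecureAverageSketch
{
    public static double[] Run(double[][] updates, double[] weights, int seed)
    {
        int n = updates.Length, m = updates[0].Length;
        var rng = new Random(seed);   // illustration only; not a secure mask source
        var masked = new double[n][];
        for (int i = 0; i < n; i++)
        {
            masked[i] = new double[m];
            for (int k = 0; k < m; k++) masked[i][k] = weights[i] * updates[i][k];   // weight first
        }
        // Antisymmetric pairwise masks: +s for client i, -s for client j.
        for (int i = 0; i < n; i++)
            for (int j = i + 1; j < n; j++)
                for (int k = 0; k < m; k++)
                {
                    double s = rng.NextDouble() * 2 - 1;
                    masked[i][k] += s;
                    masked[j][k] -= s;
                }
        // Server side: sum the masked vectors (masks cancel), then normalize by total weight.
        double totalWeight = 0;
        foreach (var w in weights) totalWeight += w;
        var avg = new double[m];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < m; k++) avg[k] += masked[i][k];
        for (int k = 0; k < m; k++) avg[k] /= totalWeight;
        return avg;
    }
}
```

For updates [1, 2] and [3, 4] with weights 1 and 3, the result is (1·[1, 2] + 3·[3, 4]) / 4 = [2.5, 3.5], up to rounding.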

Mathematical property:
  Σ(masked_update_i) = Σ(update_i + secrets_i)
                     = Σ(update_i) + Σ(secrets_i)
                     = Σ(update_i) + 0   ← secrets cancel
                     = true sum of updates

The server performs this aggregation without ever seeing individual updates!

For example, with 2 clients:
  Client 1 masked: [0.4, 0.0, 0.85] = [0.5, -0.3, 0.8] + [-0.1, 0.3, 0.05]
  Client 2 masked: [0.7, 0.1, 1.05] = [0.6, 0.4, 1.1] + [0.1, -0.3, -0.05]

Sum of masked: [1.1, 0.1, 1.9]
True sum: [0.5, -0.3, 0.8] + [0.6, 0.4, 1.1] = [1.1, 0.1, 1.9] ← Matches!
(The secrets cancel: [-0.1, 0.3, 0.05] + [0.1, -0.3, -0.05] = [0, 0, 0].)
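The server-side sum can be sketched over the same nested dictionary shape that the API uses (client ID → parameter name → values). This hypothetical helper uses double instead of the generic T, and the parameter name "layer1" is invented. It reproduces the two-client numbers above.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of an AggregateSumSecurely-style server step.
public static class MaskedSumSketch
{
    // Adds every client's masked arrays parameter-by-parameter; no division by weights.
    public static Dictionary<string, double[]> Sum(Dictionary<int, Dictionary<string, double[]>> maskedUpdates)
    {
        var total = new Dictionary<string, double[]>();
        foreach (var client in maskedUpdates.Values)
            foreach (var (layer, values) in client)
            {
                if (!total.TryGetValue(layer, out var acc))
                    total[layer] = acc = new double[values.Length];
                for (int k = 0; k < values.Length; k++) acc[k] += values[k];
            }
        return total;
    }

    // The two-client example from the text.
    public static double[] DocExample()
    {
        var masked = new Dictionary<int, Dictionary<string, double[]>>
        {
            [1] = new Dictionary<string, double[]> { ["layer1"] = new[] { 0.4, 0.0, 0.85 } },
            [2] = new Dictionary<string, double[]> { ["layer1"] = new[] { 0.7, 0.1, 1.05 } },
        };
        return Sum(masked)["layer1"];
    }
}
```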

AggregateSumSecurely(Dictionary<int, Dictionary<string, T[]>>)

Aggregates masked updates from all clients, returning the raw sum with masks cancelled.

public Dictionary<string, T[]> AggregateSumSecurely(Dictionary<int, Dictionary<string, T[]>> maskedUpdates)

Parameters

maskedUpdates Dictionary<int, Dictionary<string, T[]>>

Dictionary of client IDs to their masked updates.

Returns

Dictionary<string, T[]>

The securely aggregated model (sum of underlying updates with masks cancelled).

Remarks

This method does not divide by any weight. It returns the sum of the underlying updates after the pairwise masks cancel out.

ClearSecrets()

Clears all stored pairwise secrets.

public void ClearSecrets()

Remarks

For Beginners: Removes all secret keys from memory. Call it once aggregation is complete so that secrets do not linger.

Security best practice:

  • Generate fresh secrets for each round
  • Clear old secrets to prevent reuse
  • Minimize time secrets are stored in memory

Dispose()

Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.

public void Dispose()

~SecureAggregation()

~SecureAggregation()

GeneratePairwiseSecrets(List<int>)

Generates pairwise secrets between all clients.

public void GeneratePairwiseSecrets(List<int> clientIds)

Parameters

clientIds List<int>

List of all participating client IDs.

Remarks

For Beginners: This creates secret keys that clients will use to mask their updates. The secrets are designed so they cancel out when aggregated.

For each pair of clients (i, j):

  • Generate random secret s_ij
  • Set s_ji = -s_ij (so they cancel: s_ij + s_ji = 0)

In production, this would use:

  • Diffie-Hellman key exchange
  • Public key infrastructure
  • Secure random number generation
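The following sketches the secret table described above, using System.Security.Cryptography.RandomNumberGenerator as the secure source. Two details are assumptions: the dictionary layout keyed by (i, j) and the [-1, 1) value range. A real deployment would derive s_ij through key exchange rather than generating it centrally.

```csharp
using System;
using System.Collections.Generic;
using System.Security.Cryptography;

// Hypothetical sketch of a pairwise secret table: s_ij from a CSPRNG, s_ji = -s_ij.
public static class PairwiseSecretTableSketch
{
    public static Dictionary<(int, int), double[]> Generate(IReadOnlyList<int> clientIds, int length)
    {
        var table = new Dictionary<(int, int), double[]>();
        var bytes = new byte[8];
        for (int a = 0; a < clientIds.Count; a++)
            for (int b = a + 1; b < clientIds.Count; b++)
            {
                int i = clientIds[a], j = clientIds[b];
                var sij = new double[length];
                var sji = new double[length];
                for (int k = 0; k < length; k++)
                {
                    RandomNumberGenerator.Fill(bytes);                      // secure randomness
                    double unit = (BitConverter.ToUInt64(bytes, 0) >> 11) * (1.0 / (1UL << 53));
                    sij[k] = 2.0 * unit - 1.0;                              // uniform in [-1, 1)
                    sji[k] = -sij[k];                                       // s_ij + s_ji = 0
                }
                table[(i, j)] = sij;
                table[(j, i)] = sji;
            }
        return table;
    }
}
```

For n clients the table holds n(n - 1) entries, one per ordered pair.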

GetClientCount()

Gets the number of clients with stored secrets.

public int GetClientCount()

Returns

int

The count of clients.

MaskUpdate(int, Dictionary<string, T[]>)

Masks a client's model update with pairwise secrets.

public Dictionary<string, T[]> MaskUpdate(int clientId, Dictionary<string, T[]> clientUpdate)

Parameters

clientId int

The ID of the client whose update to mask.

clientUpdate Dictionary<string, T[]>

The client's model update.

Returns

Dictionary<string, T[]>

The masked model update.

Remarks

For Beginners: This adds secret masks to the client's update so the server can't see the original values. Only after aggregating all clients do the masks cancel out.

Mathematical operation: masked_update = original_update + Σ(secrets_with_other_clients)

For example, Client 1 with 3 clients total:

  • Original update: [0.5, -0.3, 0.8]
  • Secret with Client 2: [0.1, 0.2, -0.1]
  • Secret with Client 3: [-0.2, 0.1, 0.15]
  • Masked update: [0.4, 0.0, 0.85]

Server sees: [0.4, 0.0, 0.85] ← Cannot recover the original [0.5, -0.3, 0.8].
But after aggregating all clients, the secrets cancel and the server gets the correct sum.
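Here is a per-parameter version of this masking step, sketched with double in place of T. The helper name and the "layer1" key are hypothetical. The values are the ones from the example above.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of a MaskUpdate-style step over parameter name -> values.
public static class MaskUpdateSketch
{
    // masked_update = original_update + sum of this client's pairwise secrets (per parameter).
    public static Dictionary<string, double[]> Mask(Dictionary<string, double[]> update,
                                                    IEnumerable<Dictionary<string, double[]>> secrets)
    {
        var masked = new Dictionary<string, double[]>();
        foreach (var (layer, values) in update) masked[layer] = (double[])values.Clone();
        foreach (var secret in secrets)
            foreach (var (layer, values) in secret)
                for (int k = 0; k < values.Length; k++) masked[layer][k] += values[k];
        return masked;
    }

    // Client 1 from the example: two secrets, one shared with each other client.
    public static double[] DocExample()
    {
        var update = new Dictionary<string, double[]> { ["layer1"] = new[] { 0.5, -0.3, 0.8 } };
        var withClient2 = new Dictionary<string, double[]> { ["layer1"] = new[] { 0.1, 0.2, -0.1 } };
        var withClient3 = new Dictionary<string, double[]> { ["layer1"] = new[] { -0.2, 0.1, 0.15 } };
        return Mask(update, new[] { withClient2, withClient3 })["layer1"];
    }
}
```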

MaskUpdate(int, Dictionary<string, T[]>, double)

Masks a client's model update with pairwise secrets, applying the client's aggregation weight before masking.

public Dictionary<string, T[]> MaskUpdate(int clientId, Dictionary<string, T[]> clientUpdate, double clientWeight)

Parameters

clientId int

The ID of the client whose update to mask.

clientUpdate Dictionary<string, T[]>

The client's (unweighted) model update.

clientWeight double

The aggregation weight to apply to this client's update (e.g., sample count).

Returns

Dictionary<string, T[]>

The masked (and weighted) model update.

Remarks

For secure weighted averaging, clients must apply weights to their updates before masking so secrets still cancel. This overload multiplies the update by clientWeight and then adds the pairwise masks.