Class SecureAggregation<T>
- Namespace
- AiDotNet.FederatedLearning.Privacy
- Assembly
- AiDotNet.dll
Implements secure aggregation for federated learning using cryptographic techniques.
public class SecureAggregation<T> : FederatedLearningComponentBase<T>, IDisposable
Type Parameters
T: The numeric type for model parameters (e.g., double, float).
- Inheritance
- FederatedLearningComponentBase<T> → SecureAggregation<T>
- Implements
- IDisposable
Remarks
Secure aggregation is a cryptographic protocol that allows a server to compute the sum of client updates without seeing individual contributions. Only the final aggregate is visible to the server.
For Beginners: Secure aggregation is like a secret ballot election where votes are counted but individual votes remain private.
How it works (simplified):
- Each client generates pairwise secret keys with other clients
- Clients mask their model updates with these secret keys
- Server receives masked updates: masked_update_i = update_i + Σ(secrets_ij)
- Secret masks cancel out when summing: Σ(masked_update_i) = Σ(update_i)
- Server gets the sum without seeing individual updates
Example with 3 clients:
- Client 1 shares secrets: s₁₂ with Client 2, s₁₃ with Client 3
- Client 2 shares secrets: s₂₁ with Client 1, s₂₃ with Client 3
- Client 3 shares secrets: s₃₁ with Client 1, s₃₂ with Client 2
Note: s_ij = -s_ji for every pair (s₁₂ = -s₂₁, s₁₃ = -s₃₁, s₂₃ = -s₃₂), so the secrets cancel in pairs.
Client 1 sends: update₁ + s₁₂ + s₁₃
Client 2 sends: update₂ + s₂₁ + s₂₃
Client 3 sends: update₃ + s₃₁ + s₃₂
Server computes sum:
(update₁ + s₁₂ + s₁₃) + (update₂ + s₂₁ + s₂₃) + (update₃ + s₃₁ + s₃₂)
= update₁ + update₂ + update₃ + (s₁₂ + s₂₁) + (s₁₃ + s₃₁) + (s₂₃ + s₃₂)
= update₁ + update₂ + update₃ + 0 + 0 + 0
= Σ(updates) ← Only this is visible to server!
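The three-client walkthrough can be reproduced with a minimal, self-contained C# sketch. It does not use SecureAggregation<T> itself: PairwiseMaskDemo and its methods are hypothetical names introduced only for illustration, and System.Random stands in for the cryptographic mask derivation, so this is a toy model of the cancellation property rather than the library's implementation.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration type; not part of AiDotNet.
public static class PairwiseMaskDemo
{
    // Builds antisymmetric pairwise secrets: secrets[(i, j)] is the negation of secrets[(j, i)].
    public static Dictionary<(int, int), double[]> BuildSecrets(
        IReadOnlyList<int> clientIds, int length, int seed)
    {
        var rng = new Random(seed); // NOT cryptographically secure; demo only.
        var secrets = new Dictionary<(int, int), double[]>();
        for (int a = 0; a < clientIds.Count; a++)
        {
            for (int b = a + 1; b < clientIds.Count; b++)
            {
                var s = new double[length];
                for (int k = 0; k < length; k++) s[k] = rng.NextDouble() * 2.0 - 1.0;
                secrets[(clientIds[a], clientIds[b])] = s;
                secrets[(clientIds[b], clientIds[a])] = s.Select(v => -v).ToArray();
            }
        }
        return secrets;
    }

    // masked_update_i = update_i + sum over j != i of s_ij
    public static double[] Mask(
        int clientId, double[] update, Dictionary<(int, int), double[]> secrets)
    {
        var masked = (double[])update.Clone();
        foreach (var entry in secrets)
        {
            if (entry.Key.Item1 != clientId) continue; // only this client's secrets
            for (int k = 0; k < masked.Length; k++) masked[k] += entry.Value[k];
        }
        return masked;
    }

    // The server only ever adds masked vectors together.
    public static double[] Sum(IEnumerable<double[]> vectors, int length)
    {
        var total = new double[length];
        foreach (double[] v in vectors)
            for (int k = 0; k < length; k++) total[k] += v[k];
        return total;
    }
}
```

Masking the updates of clients 1, 2, and 3 with Mask and passing the results to Sum yields the same vector as summing the raw updates, up to floating-point rounding.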
This implementation derives pairwise mask seeds from per-round ephemeral ECDH shared secrets and expands them via HKDF + a deterministic PRG. Pairwise masks cancel in the aggregate as long as all selected clients participate in the round (synchronous, full-participation mode).
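To make the ECDH, HKDF, and PRG pipeline concrete, here is a hedged sketch built only on System.Security.Cryptography (HKDF requires .NET 5 or later). It shows one plausible construction, not necessarily what this class does internally: the MaskSeedSketch type, the NIST P-256 curve used in the usage note, the info string format, the HMAC counter-mode PRG, and the "lower ID adds, higher ID subtracts" sign rule are all assumptions made for illustration.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Hypothetical illustration type; not part of AiDotNet.
public static class MaskSeedSketch
{
    // Derives a 32-byte pairwise seed from an ECDH shared secret via HKDF-SHA256.
    // The info string binds the seed to one round; its format is an assumption.
    public static byte[] DeriveSeed(
        ECDiffieHellman mine, ECDiffieHellmanPublicKey theirs, int round)
    {
        byte[] shared = mine.DeriveKeyFromHash(theirs, HashAlgorithmName.SHA256);
        byte[] info = Encoding.UTF8.GetBytes("secagg-mask-round-" + round);
        return HKDF.DeriveKey(HashAlgorithmName.SHA256, shared, 32, salt: null, info: info);
    }

    // Deterministic PRG: HMAC-SHA256 in counter mode, mapped to doubles in [-1, 1).
    public static double[] ExpandMask(byte[] seed, int length)
    {
        var mask = new double[length];
        using var hmac = new HMACSHA256(seed);
        for (int k = 0; k < length; k++)
        {
            byte[] block = hmac.ComputeHash(BitConverter.GetBytes(k));
            ulong bits = BitConverter.ToUInt64(block, 0) >> 11; // keep 53 bits
            mask[k] = bits / (double)(1UL << 53) * 2.0 - 1.0;
        }
        return mask;
    }

    // The lower ID adds the mask and the higher ID subtracts it,
    // so the two contributions of a pair sum to zero.
    public static double[] SignedMask(byte[] seed, int length, int myId, int peerId)
    {
        double sign = myId < peerId ? 1.0 : -1.0;
        double[] mask = ExpandMask(seed, length);
        for (int k = 0; k < length; k++) mask[k] *= sign;
        return mask;
    }
}
```

Two parties that each call ECDiffieHellman.Create(ECCurve.NamedCurves.nistP256) and exchange PublicKey values obtain the same seed from DeriveSeed, and therefore masks that are exact negatives of each other. Using fresh key pairs every round is what makes the shared secrets ephemeral.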
Benefits:
- Server cannot see individual client updates
- Protects against honest-but-curious server
- No trusted third party needed
- Computation overhead is modest: masking is vector addition, plus one key agreement and one mask expansion per peer
Limitations:
- Requires coordination between clients
- In this implementation, all selected clients must participate in a round for the masks to cancel
- Dropout handling requires additional mechanisms
- Communication overhead for key exchange
When to use Secure Aggregation:
- Don't fully trust the central server
- Regulatory requirements for data protection
- Want cryptographic privacy guarantees
- Willing to handle additional complexity
Can be combined with differential privacy for stronger protection:
- Secure aggregation: Protects individual updates from server
- Differential privacy: Protects individual data points from anyone
Reference: Bonawitz, K., et al. (2017). "Practical Secure Aggregation for Privacy-Preserving Machine Learning." CCS 2017.
Constructors
SecureAggregation(int, int?)
Initializes a new instance of the SecureAggregation<T> class.
public SecureAggregation(int parameterCount, int? randomSeed = null)
Parameters
parameterCount (int): The total number of model parameters to protect.
randomSeed (int?): Optional random seed for reproducibility.
Remarks
For Beginners: Sets up the secure aggregation protocol for a specific number of model parameters.
In practice, this would involve:
- Secure key exchange between clients
- Authenticated channels
- Agreement on random seed for deterministic mask generation
In this in-memory implementation, pairwise masks are valid for a single round and must be regenerated before each new round.
Methods
AggregateSecurely(Dictionary<int, Dictionary<string, T[]>>, Dictionary<int, double>)
Aggregates masked updates from all clients, returning a weighted average.
public Dictionary<string, T[]> AggregateSecurely(Dictionary<int, Dictionary<string, T[]>> maskedUpdates, Dictionary<int, double> clientWeights)
Parameters
maskedUpdates (Dictionary<int, Dictionary<string, T[]>>): Dictionary of client IDs to their masked updates.
clientWeights (Dictionary<int, double>): Dictionary of client IDs to their aggregation weights.
Returns
- Dictionary<string, T[]>
The securely aggregated model (weighted average if clients pre-weighted their updates before masking).
Remarks
For Beginners: This sums up all the masked updates. Because the secret masks cancel out, the server recovers the sum of the underlying (possibly weighted) client updates without ever seeing any individual update.
To compute a weighted average securely:
- Each client must apply its weight to its update before masking (use MaskUpdate(int, Dictionary<string, T[]>, double)).
- The server then divides the summed masked updates by the total weight.
If you need the raw (un-normalized) sum of updates, use AggregateSumSecurely(Dictionary<int, Dictionary<string, T[]>>).
Mathematical property:
Σ(masked_update_i) = Σ(update_i + secrets_i)
= Σ(update_i) + Σ(secrets_i)
= Σ(update_i) + 0 ← secrets cancel
= True sum of updates
The server performs this aggregation without ever seeing individual updates!
For example with 2 clients:
Client 1 masked: [0.4, 0.0, 0.85] = [0.5, -0.3, 0.8] + [-0.1, 0.3, 0.05]
Client 2 masked: [0.7, 0.1, 1.05] = [0.6, 0.4, 1.1] + [0.1, -0.3, -0.05]
Sum of masked: [1.1, 0.1, 1.9]
True sum: [0.5, -0.3, 0.8] + [0.6, 0.4, 1.1] = [1.1, 0.1, 1.9] ← Matches!
(Note: Secrets [-0.1, 0.3, 0.05] + [0.1, -0.3, -0.05] = [0, 0, 0] ← Cancelled)
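The weighted-average recipe (clients weight before masking, the server divides by the total weight) can be sketched in a few lines. This is a standalone illustration under stated assumptions, not the body of AggregateSecurely: WeightedSecureAverageSketch is a hypothetical type, and each client's pairwise masks are assumed to be already combined into a single netMask vector.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical illustration type; not part of AiDotNet.
public static class WeightedSecureAverageSketch
{
    // Client side: scale by the weight first, then add the client's net mask.
    public static double[] WeightThenMask(double[] update, double weight, double[] netMask)
    {
        var result = new double[update.Length];
        for (int k = 0; k < update.Length; k++)
            result[k] = update[k] * weight + netMask[k];
        return result;
    }

    // Server side: sum the masked vectors (masks cancel), then divide by the total weight.
    public static double[] Average(IReadOnlyList<double[]> masked, IReadOnlyList<double> weights)
    {
        if (masked.Count == 0)
            throw new ArgumentException("At least one update is required.", nameof(masked));

        double totalWeight = 0.0;
        foreach (double w in weights) totalWeight += w;
        if (totalWeight <= 0.0)
            throw new ArgumentException("Total weight must be positive.", nameof(weights));

        var avg = new double[masked[0].Length];
        foreach (double[] m in masked)
            for (int k = 0; k < avg.Length; k++) avg[k] += m[k];
        for (int k = 0; k < avg.Length; k++) avg[k] /= totalWeight;
        return avg;
    }
}
```

With the two-client updates and masks from the example and weights 1 and 3, Average returns approximately [0.575, 0.225, 1.025], which is the weighted mean of the unmasked updates.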
AggregateSumSecurely(Dictionary<int, Dictionary<string, T[]>>)
Aggregates masked updates from all clients, returning the raw sum with masks cancelled.
public Dictionary<string, T[]> AggregateSumSecurely(Dictionary<int, Dictionary<string, T[]>> maskedUpdates)
Parameters
maskedUpdates (Dictionary<int, Dictionary<string, T[]>>): Dictionary of client IDs to their masked updates.
Returns
- Dictionary<string, T[]>
The securely aggregated model (sum of underlying updates with masks cancelled).
Remarks
This method does not divide by any weight. It returns the sum of the underlying updates after the pairwise masks cancel out.
ClearSecrets()
Clears all stored pairwise secrets.
public void ClearSecrets()
Remarks
For Beginners: Removes all secret keys from memory. Call it once aggregation is complete so that secrets are not retained longer than necessary.
Security best practice:
- Generate fresh secrets for each round
- Clear old secrets to prevent reuse
- Minimize time secrets are stored in memory
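One way to follow these practices in .NET is to overwrite secret buffers before releasing them. The sketch below is an assumption about how such cleanup could look, not a description of what ClearSecrets does internally; RoundSecretStore is a hypothetical type, while CryptographicOperations.ZeroMemory is a real .NET API.

```csharp
using System;
using System.Collections.Generic;
using System.Security.Cryptography;

// Hypothetical holder for per-round seed material; not part of AiDotNet.
public sealed class RoundSecretStore : IDisposable
{
    private readonly Dictionary<(int, int), byte[]> _seeds =
        new Dictionary<(int, int), byte[]>();

    public int Count => _seeds.Count;

    // Stores a reference to the caller's buffer (no copy is made).
    public void Add(int clientA, int clientB, byte[] seed) =>
        _seeds[(clientA, clientB)] = seed;

    // Overwrites each seed with zeros before dropping the reference.
    public void Clear()
    {
        foreach (byte[] seed in _seeds.Values)
            CryptographicOperations.ZeroMemory(seed);
        _seeds.Clear();
    }

    public void Dispose() => Clear();
}
```

Wrapping a round in a using statement then guarantees that Clear runs even if aggregation throws. ZeroMemory is used instead of Array.Clear because it is documented not to be optimized away.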
Dispose()
Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
public void Dispose()
~SecureAggregation()
protected ~SecureAggregation()
GeneratePairwiseSecrets(List<int>)
Generates pairwise secrets between all clients.
public void GeneratePairwiseSecrets(List<int> clientIds)
Parameters
clientIds (List<int>): The IDs of the clients participating in the round.
Remarks
For Beginners: This creates secret keys that clients will use to mask their updates. The secrets are designed so they cancel out when aggregated.
For each pair of clients (i, j):
- Generate random secret s_ij
- Set s_ji = -s_ij (so they cancel: s_ij + s_ji = 0)
In production, this would use:
- Diffie-Hellman key exchange
- Public key infrastructure
- Secure random number generation
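As an illustration of the per-pair loop combined with secure random number generation, the following sketch draws one seed per unordered pair from the operating system's CSPRNG. PairwiseSecretSketch is a hypothetical type, the 32-byte seed size and the sign convention are assumptions, and the static RandomNumberGenerator.GetBytes(int) overload requires .NET 6 or later. In a real deployment each pair would agree on its seed via key exchange rather than having one party generate all of them.

```csharp
using System;
using System.Collections.Generic;
using System.Security.Cryptography;

// Hypothetical illustration type; not part of AiDotNet.
public static class PairwiseSecretSketch
{
    // One seed per unordered pair {i, j}, keyed with the smaller ID first.
    public static Dictionary<(int Low, int High), byte[]> Generate(
        IReadOnlyList<int> clientIds, int seedBytes = 32)
    {
        var seeds = new Dictionary<(int Low, int High), byte[]>();
        for (int a = 0; a < clientIds.Count; a++)
        {
            for (int b = a + 1; b < clientIds.Count; b++)
            {
                int low = Math.Min(clientIds[a], clientIds[b]);
                int high = Math.Max(clientIds[a], clientIds[b]);
                if (low == high) continue; // ignore duplicate IDs
                seeds[(low, high)] = RandomNumberGenerator.GetBytes(seedBytes);
            }
        }
        return seeds;
    }

    // s_ij = +mask when i < j and -mask when i > j, so s_ij + s_ji = 0.
    public static int Sign(int clientId, int peerId) => clientId < peerId ? 1 : -1;
}
```

For n distinct clients this produces n(n-1)/2 seeds, so storing one seed per pair and applying Sign at masking time halves the storage compared with keeping both s_ij and s_ji.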
GetClientCount()
Gets the number of clients with stored secrets.
public int GetClientCount()
Returns
- int
The count of clients.
MaskUpdate(int, Dictionary<string, T[]>)
Masks a client's model update with pairwise secrets.
public Dictionary<string, T[]> MaskUpdate(int clientId, Dictionary<string, T[]> clientUpdate)
Parameters
clientId (int): The ID of the client whose update to mask.
clientUpdate (Dictionary<string, T[]>): The client's model update.
Returns
- Dictionary<string, T[]>
The masked model update.
Remarks
For Beginners: This adds secret masks to the client's update so the server can't see the original values. Only after aggregating all clients do the masks cancel out.
Mathematical operation: masked_update = original_update + Σ(secrets_with_other_clients)
For example, Client 1 with 3 clients total:
- Original update: [0.5, -0.3, 0.8]
- Secret with Client 2: [0.1, 0.2, -0.1]
- Secret with Client 3: [-0.2, 0.1, 0.15]
- Masked update: [0.4, 0.0, 0.85]
Server sees: [0.4, 0.0, 0.85] ← Cannot recover original [0.5, -0.3, 0.8]
But after aggregating all clients, secrets cancel and server gets correct sum.
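The masking operation can be sketched for the dictionary-shaped updates this API accepts. The code is an illustration under assumptions rather than the library's implementation: MaskUpdateSketch is a hypothetical type, each secret is assumed to be a flat vector spanning all parameters, and layers are visited in ordinal key order so that every client maps mask offsets to layers identically.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration type; not part of AiDotNet.
public static class MaskUpdateSketch
{
    // masked_update = original_update + sum of this client's secrets with every other client.
    public static Dictionary<string, double[]> Mask(
        Dictionary<string, double[]> update, IEnumerable<double[]> secretsWithPeers)
    {
        var peers = secretsWithPeers.ToList();
        var masked = new Dictionary<string, double[]>();
        int offset = 0; // position of the current layer inside each flat secret

        foreach (string layer in update.Keys.OrderBy(k => k, StringComparer.Ordinal))
        {
            double[] source = update[layer];
            var target = (double[])source.Clone(); // leave the caller's arrays untouched
            for (int k = 0; k < target.Length; k++)
                foreach (double[] secret in peers)
                    target[k] += secret[offset + k];
            masked[layer] = target;
            offset += source.Length;
        }
        return masked;
    }
}
```

Calling Mask with a single layer holding [0.5, -0.3, 0.8] and the two secrets from the example yields approximately [0.4, 0.0, 0.85].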
MaskUpdate(int, Dictionary<string, T[]>, double)
Masks a client's model update with pairwise secrets, applying the client's aggregation weight before masking.
public Dictionary<string, T[]> MaskUpdate(int clientId, Dictionary<string, T[]> clientUpdate, double clientWeight)
Parameters
clientId (int): The ID of the client whose update to mask.
clientUpdate (Dictionary<string, T[]>): The client's (unweighted) model update.
clientWeight (double): The aggregation weight to apply to this client's update (e.g., sample count).
Returns
- Dictionary<string, T[]>
The masked (and weighted) model update.
Remarks
For secure weighted averaging, clients must apply weights to their updates before masking so secrets still cancel.
This overload multiplies the update by clientWeight and then adds the pairwise masks.
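A short derivation shows why the order matters. Here u_i is client i's update, w_i its weight, and s_ij the pairwise secret with s_ij = -s_ji; these symbols are shorthand introduced for this derivation only.

```latex
\begin{align*}
\text{weight, then mask:}\quad
\sum_i \Bigl( w_i u_i + \sum_{j \ne i} s_{ij} \Bigr)
  &= \sum_i w_i u_i + \sum_{i<j} \bigl( s_{ij} + s_{ji} \bigr)
   = \sum_i w_i u_i \\
\text{mask, then weight:}\quad
\sum_i w_i \Bigl( u_i + \sum_{j \ne i} s_{ij} \Bigr)
  &= \sum_i w_i u_i + \sum_{i<j} \bigl( w_i - w_j \bigr)\, s_{ij}
\end{align*}
```

In the second case a residual mask term survives whenever two clients have different weights, so the server would receive a corrupted sum. Weighting first leaves the masks unscaled, and dividing the result by the total weight Σ w_i gives the weighted average.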