Interface IClientModel<TData, TUpdate>
Namespace: AiDotNet.Interfaces
Assembly: AiDotNet.dll
Defines the functionality for a client-side model in federated learning.
public interface IClientModel<TData, TUpdate>
Type Parameters
TData: The type of the local training data.
TUpdate: The type of the model update sent to the server.
Remarks
This interface represents a model that exists on a client device or node in a federated learning system. Each client maintains its own copy of the global model and trains it on local data.
For Beginners: A client model is like a student's personal copy of study materials. Each student (client) has their own copy, studies it with their own resources, and contributes improvements back to the class.
Think of client models as distributed learners:
- Each client has a copy of the global model
- Clients train on their own private data
- Local training happens independently and in parallel
- Only model updates (not data) are sent to the server
For example, in smartphone keyboard prediction:
- Each phone has a copy of the global typing prediction model
- The phone learns from the user's typing patterns
- It sends model improvements (not actual typed text) to the server
- The server combines improvements from millions of phones
- Each phone gets the improved model back
This design provides:
- Data privacy: Raw data never leaves the client
- Personalization: Local models can adapt to their own data distribution
- Scalability: Training happens in parallel across all participating clients
Methods
GetModelUpdate()
Computes and retrieves the model update to send to the server.
TUpdate GetModelUpdate()
Returns
TUpdate: The model update containing weight changes or gradients.
Remarks
The model update represents the improvements the client made through local training. It is typically the difference between the locally trained weights and the global weights the client started the round with.
For Beginners: This is like preparing a summary of what you learned from studying, rather than sharing your entire study materials. You share the insights, not the sources.
The update typically contains:
- Weight differences: new weights minus original weights
- Gradients: how the local loss changes with respect to each weight (the negative gradient points toward lower loss)
- Metadata: Number of local samples, local loss, etc.
For example:
- Original weight for feature "age": 0.5
- After training, weight for "age": 0.6
- Update to send: +0.1
- This tells the server how to adjust that weight
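The element-wise subtraction described above can be sketched as follows. This is a minimal illustration that assumes the model's weights are stored in a flat double[]; WeightDelta.Compute is a hypothetical helper, not part of the AiDotNet API.

```csharp
using System;

// Hypothetical helper: not an AiDotNet type.
public static class WeightDelta
{
    // Returns trained[i] - original[i] for every weight.
    public static double[] Compute(double[] trained, double[] original)
    {
        if (trained.Length != original.Length)
            throw new ArgumentException("Weight vectors must have the same length.");

        var delta = new double[trained.Length];
        for (int i = 0; i < trained.Length; i++)
        {
            delta[i] = trained[i] - original[i];
        }
        return delta;
    }
}
```

For the "age" example, WeightDelta.Compute(new[] { 0.6 }, new[] { 0.5 }) yields a delta of approximately 0.1; floating-point subtraction means it may not be exactly 0.1.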
GetSampleCount()
Gets the number of training samples available on this client.
int GetSampleCount()
Returns
int: The number of training samples on this client.
Remarks
Sample count is used to weight client contributions during aggregation. Clients with more data typically receive higher weights.
For Beginners: This is like indicating how many practice problems you solved. If you solved 1000 problems and someone else solved 100, your insights about problem-solving patterns are likely more reliable.
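One common way to use the sample count is a FedAvg-style weighted average: with n_k samples on client k and n samples in total, client k's update is scaled by n_k / n before the updates are summed. The sketch below assumes each update is a flat double[] delta; SampleWeightedAverage.Aggregate is a hypothetical helper, and AiDotNet's own aggregation strategy may use a different rule.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical FedAvg-style aggregator: not an AiDotNet type.
public static class SampleWeightedAverage
{
    // Combines client deltas, weighting each by its share of the total sample count.
    public static double[] Aggregate(IReadOnlyList<(double[] Delta, int SampleCount)> updates)
    {
        if (updates.Count == 0)
            throw new ArgumentException("At least one client update is required.");

        long totalSamples = 0;
        foreach (var u in updates) totalSamples += u.SampleCount;
        if (totalSamples <= 0)
            throw new ArgumentException("Total sample count must be positive.");

        int length = updates[0].Delta.Length;
        var result = new double[length];
        foreach (var (delta, sampleCount) in updates)
        {
            if (delta.Length != length)
                throw new ArgumentException("All deltas must have the same length.");

            double weight = (double)sampleCount / totalSamples; // n_k / n
            for (int i = 0; i < length; i++)
                result[i] += weight * delta[i];
        }
        return result;
    }
}
```

With the 1000-versus-100 example above, deltas of 1.0 and 0.0 combine to about 0.909, because the first client's weight is 1000 / 1100.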
TrainLocal(TData, int, double)
Trains the local model on the client's private data.
void TrainLocal(TData localData, int epochs, double learningRate)
Parameters
localData (TData): The client's private training data.
epochs (int): Number of full passes over the local data.
learningRate (double): Step size for gradient-descent optimization.
Remarks
Local training is the core of federated learning: each client improves the model using its own data without sharing that data with anyone.
For Beginners: This is like studying independently with your own materials. You use your personal notes and resources to learn, and later share what you learned, not the actual materials.
The training process:
- Receive the global model from the server
- Train on local data for specified number of epochs
- Compute the difference between updated and original model (the "update")
- Prepare this update to send back to the server
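The second step, training for a given number of epochs, can be sketched for a deliberately tiny model. This assumes a one-feature linear regression y ≈ w * x + b trained with full-batch gradient descent on mean squared error; LocalTrainer and its tuple-array data format are hypothetical choices, not AiDotNet types.

```csharp
using System;

// Hypothetical local trainer for y ≈ w * x + b: not an AiDotNet type.
public static class LocalTrainer
{
    // Runs `epochs` full passes of gradient descent on mean squared error,
    // starting from the global weights [w, b]. Returns the locally trained
    // weights; the globalWeights array itself is not modified.
    public static double[] Train((double X, double Y)[] data, double[] globalWeights,
                                 int epochs, double learningRate)
    {
        double w = globalWeights[0], b = globalWeights[1];
        int n = data.Length;

        for (int epoch = 0; epoch < epochs; epoch++)
        {
            double gradW = 0, gradB = 0;
            foreach (var (x, y) in data)
            {
                double error = (w * x + b) - y; // prediction minus target
                gradW += 2 * error * x / n;     // d(MSE)/dw
                gradB += 2 * error / n;         // d(MSE)/db
            }
            w -= learningRate * gradW;          // step against the gradient
            b -= learningRate * gradB;
        }
        return new[] { w, b };
    }
}
```

Because the function returns new weights instead of overwriting the global ones, the client still has the round's starting point available when it computes the update in the third step.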
For example:
- Client receives global model with accuracy 80%
- Trains on local data for 5 epochs
- Local model now has accuracy 85% on local data
- Computes weight changes (delta) that improved the model
- Sends these weight changes to server, not the local data
UpdateFromGlobal(TUpdate)
Updates the local model with the new global model from the server.
void UpdateFromGlobal(TUpdate globalModelUpdate)
Parameters
globalModelUpdate (TUpdate): The aggregated global model from the server.
Remarks
After the server aggregates updates from all clients, it sends the improved global model back to clients for the next round of training.
For Beginners: This is like receiving the updated textbook that incorporates everyone's contributions. You replace your old version with this improved version before the next study session.
The update process:
- Receive aggregated global model from server
- Replace local model weights with global model weights
- Optionally keep some personalized layers
- Ready for next round of local training
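The replace-then-keep steps above can be sketched as follows. This assumes weights are a flat double[] and that personalization is expressed as a set of weight indices the client keeps; PersonalizedSync and personalizedIndices are hypothetical, and real personalization schemes usually operate on whole layers.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: not an AiDotNet type.
public static class PersonalizedSync
{
    // Returns new local weights: global values everywhere except the
    // personalized indices, which keep the client's current local values.
    public static double[] Apply(double[] localWeights, double[] globalWeights,
                                 ISet<int> personalizedIndices)
    {
        if (localWeights.Length != globalWeights.Length)
            throw new ArgumentException("Weight vectors must have the same length.");

        var updated = (double[])globalWeights.Clone(); // start from the global model
        foreach (int i in personalizedIndices)
            updated[i] = localWeights[i];               // keep local personalization
        return updated;
    }
}
```

Passing an empty set gives the plain behavior of replacing every local weight with its global value.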
For example:
- Round 1: Trained local model, sent update
- Server aggregated all updates
- Round 2: Receive improved global model
- Use this as starting point for next round of training
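Putting the four methods together, the round-by-round cycle above can be simulated end to end. This is a sketch under stated assumptions: the interface is restated with the signatures documented on this page, while LinearClient (a one-feature linear model), Simulation.Run, the fixed 5 local epochs, the 0.05 learning rate, and sample-weighted averaging of deltas on the server are all hypothetical choices for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Signatures as documented for AiDotNet.Interfaces.IClientModel.
public interface IClientModel<TData, TUpdate>
{
    TUpdate GetModelUpdate();
    int GetSampleCount();
    void TrainLocal(TData localData, int epochs, double learningRate);
    void UpdateFromGlobal(TUpdate globalModelUpdate);
}

// Hypothetical client: y ≈ w * x + b, TData = (X, Y) pairs, TUpdate = double[] { w, b }.
public sealed class LinearClient : IClientModel<(double X, double Y)[], double[]>
{
    private double[] _start = { 0.0, 0.0 };   // global weights at the start of the round
    private double[] _weights = { 0.0, 0.0 }; // current local weights
    private int _samples;

    public void UpdateFromGlobal(double[] globalModel)
    {
        _start = (double[])globalModel.Clone();
        _weights = (double[])globalModel.Clone();
    }

    public void TrainLocal((double X, double Y)[] data, int epochs, double learningRate)
    {
        _samples = data.Length;
        for (int e = 0; e < epochs; e++)
        {
            double gw = 0, gb = 0;
            foreach (var (x, y) in data)
            {
                double err = _weights[0] * x + _weights[1] - y;
                gw += 2 * err * x / data.Length; // d(MSE)/dw
                gb += 2 * err / data.Length;     // d(MSE)/db
            }
            _weights[0] -= learningRate * gw;
            _weights[1] -= learningRate * gb;
        }
    }

    public int GetSampleCount() => _samples;

    // Update = trained weights minus the global weights this round started from.
    public double[] GetModelUpdate() => _weights.Zip(_start, (a, b) => a - b).ToArray();
}

public static class Simulation
{
    // Each round: broadcast global weights, train locally, average deltas by sample count.
    public static double[] Run(List<(double X, double Y)[]> clientData, int rounds)
    {
        var clients = clientData.Select(_ => new LinearClient()).ToList();
        var global = new[] { 0.0, 0.0 };
        int total = clientData.Sum(d => d.Length);

        for (int r = 0; r < rounds; r++)
        {
            var next = (double[])global.Clone();
            for (int k = 0; k < clients.Count; k++)
            {
                clients[k].UpdateFromGlobal(global);
                clients[k].TrainLocal(clientData[k], epochs: 5, learningRate: 0.05);
                double share = (double)clients[k].GetSampleCount() / total;
                double[] delta = clients[k].GetModelUpdate();
                for (int i = 0; i < next.Length; i++) next[i] += share * delta[i];
            }
            global = next; // becomes the starting point of the next round
        }
        return global;
    }
}
```

When every client's data lies on the same line, as with two clients holding points from y = 2x + 1, the global weights approach w = 2 and b = 1 after a few dozen rounds. With heterogeneous client data, plain averaging of this kind can settle away from the centralized optimum, which is one motivation for the personalized layers mentioned above.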