Usage
This section outlines the steps required to integrate STADLE into a basic deep learning training process. Refer to the Quickstart section to set up the client environment for connecting to STADLE. Sample code that already includes the STADLE integration is also available for download.
STADLE Aggregator Functionalities
STADLE aggregators can be configured through the stadle.ai dashboard, as explained in the Quickstart and User Guide.
Creating a Project
Once you create an account, the first step is to create a new project from the Overview page. Each project corresponds to one AI model architecture (e.g., VGG16). To federate different AI models, you must create separate projects, because the model architecture must remain consistent across all agents connected to a given aggregator.
Note
A free account allows the creation of only one project.
Initiating an Aggregator
After creating a project, you can initiate one or more aggregators from the Overview page. If you want to set up decentralized aggregation with multiple aggregators, you can launch several aggregator instances within the same project to enable the synthesis of Semi-Global Models (SG Models).
Note
A free account allows you to initiate only one aggregator.
Downloading Models
You can download the latest:
Global ML models
Local models
Best-performing models
These are accessible from your STADLE project dashboard.
Completing the Current Round
This feature allows you to force the aggregator to complete the current round of aggregation. Normally, an aggregator waits to collect a sufficient number of local models before proceeding. However, using the Complete Current Round option, you can manually trigger aggregation even if the collection threshold has not been met.
Aggregation Threshold
The Aggregation Threshold determines the proportion of local models required from active agents to proceed with aggregation. For example, a threshold of 0.7 means that 70% of the active agents must submit their models for aggregation to occur.
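The threshold arithmetic can be sketched in a few lines of Python. This is an illustration only, not STADLE's implementation: the function names are hypothetical, rounding up to a whole number of models is an assumption, and the force flag simply models the Complete Current Round option described earlier.

```python
import math


def models_required(active_agents: int, threshold: float) -> int:
    """Number of local models needed before aggregation can start.

    Rounding up is an assumption; the small epsilon guards against
    floating-point noise such as 3 * 0.7 == 2.0999999999999996.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    return math.ceil(active_agents * threshold - 1e-9)


def should_aggregate(collected: int, active_agents: int,
                     threshold: float, force: bool = False) -> bool:
    """True when enough models have arrived, or when the round is forced."""
    return force or collected >= models_required(active_agents, threshold)


# With 10 active agents and a threshold of 0.7, 7 models are needed.
needed = models_required(10, 0.7)
```

With these assumptions, should_aggregate(6, 10, 0.7) is False, while the same call with force=True returns True.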
Agent Timeout
This feature disconnects inactive agents based on a timeout interval. If an agent is unresponsive for a user-defined number of seconds, it is marked as TIMEOUT and excluded from the aggregation process. If the agent reconnects, it will be included again.
A timeout of 0 disables this functionality.
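The timeout rule can be pictured as a check on each agent's last contact time. The sketch below is a hypothetical model of that rule, not STADLE code: AgentRecord, agent_status, and eligible_agents are illustrative names, and the "ACTIVE" label for a responsive agent is a placeholder (the text above only names the TIMEOUT status).

```python
from dataclasses import dataclass


@dataclass
class AgentRecord:
    agent_id: str
    last_seen: float  # seconds on any monotonic clock


def agent_status(agent: AgentRecord, now: float, timeout_seconds: float) -> str:
    """Return "TIMEOUT" for an agent silent longer than the timeout."""
    if timeout_seconds == 0:
        # A timeout of 0 disables the check entirely.
        return "ACTIVE"
    silent_for = now - agent.last_seen
    return "TIMEOUT" if silent_for > timeout_seconds else "ACTIVE"


def eligible_agents(agents, now: float, timeout_seconds: float):
    """IDs of the agents that still count toward aggregation."""
    return [a.agent_id for a in agents
            if agent_status(a, now, timeout_seconds) == "ACTIVE"]


agents = [AgentRecord("agent-a", last_seen=100.0),
          AgentRecord("agent-b", last_seen=40.0)]
before_reconnect = eligible_agents(agents, now=105.0, timeout_seconds=30)

# A reconnecting agent refreshes its last contact time and is included again.
agents[1].last_seen = 104.0
after_reconnect = eligible_agents(agents, now=105.0, timeout_seconds=30)
```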
Aggregation Method Selection
While FedAvg is the default aggregation algorithm, STADLE supports a variety of methods to suit different applications:
FedAvg
Geometric Median
Coordinate-Wise Median
Krum
Krum Averaging
Choose the method that best aligns with your model’s robustness and convergence requirements.
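To make the trade-off concrete, here is a minimal sketch comparing two of the listed methods on flat parameter lists. It follows the textbook definitions (FedAvg as a mean weighted by each agent's sample count, Coordinate-Wise Median as a per-parameter median); whether STADLE weights or flattens models this way is an assumption, and the function names are illustrative.

```python
from statistics import median


def fedavg(models, sample_counts):
    """Mean of the models, weighted by each agent's sample count."""
    total = sum(sample_counts)
    dim = len(models[0])
    return [sum(m[i] * n for m, n in zip(models, sample_counts)) / total
            for i in range(dim)]


def coordinate_wise_median(models):
    """Median taken independently for every parameter position."""
    return [median(coords) for coords in zip(*models)]


# Two agents agree; a third submits an extreme (faulty or malicious) update.
updates = [[1.0, 2.0], [1.0, 2.0], [100.0, -50.0]]

averaged = fedavg(updates, sample_counts=[1, 1, 1])
robust = coordinate_wise_median(updates)  # [1.0, 2.0]
```

The average is pulled far from the two agreeing agents by the single outlier, while the median ignores it. That is the kind of robustness difference to weigh when choosing a method.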
Synthesize Semi-Global Models
STADLE supports decentralized aggregation by allowing multiple aggregators to work together to produce Semi-Global Models (SG Models). This decentralized structure enables scalable and efficient global model synthesis across distributed clusters.
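One way to picture two-level aggregation is sketched below. The source does not specify how SG Models are combined, so everything here is an assumption for illustration: each aggregator averages its own agents' models, and the resulting cluster models are then averaged with weights equal to the number of agents behind each one. All function names are hypothetical.

```python
def weighted_mean(models, weights):
    """Weighted average of flat parameter lists."""
    total = sum(weights)
    dim = len(models[0])
    return [sum(m[i] * w for m, w in zip(models, weights)) / total
            for i in range(dim)]


def cluster_model(local_models):
    """One aggregator: average its agents' models and report the agent count."""
    count = len(local_models)
    return weighted_mean(local_models, [1] * count), count


def combine_clusters(cluster_results):
    """Combine cluster models, weighting each by its agent count."""
    models = [model for model, _ in cluster_results]
    counts = [count for _, count in cluster_results]
    return weighted_mean(models, counts)


cluster_a = cluster_model([[1.0], [3.0]])  # two agents
cluster_b = cluster_model([[8.0]])         # one agent
combined = combine_clusters([cluster_a, cluster_b])
```

Under this weighting the combined model equals the plain average over all three agents, so splitting agents across aggregators does not change the result in this simplified setting.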
Aggregation Management
On the Aggregation Management page, you can monitor:
Current aggregation round
Maximum number of connectable agents
Number of active agents
Number of local models required for aggregation
Number of models already collected
Performance Tracking
Track the performance of local ML models on both the Dashboard and the Performance Tracking page. Metrics are recorded for each aggregation round to help you monitor training progress and model accuracy.
Stopping & Restarting Aggregators
Aggregators can be stopped or restarted from the Config Info & Settings page. Their status will update to INACTIVE or ACTIVE based on the action performed.
Client-side STADLE Integration
This section covers how to integrate STADLE with existing PyTorch code that trains a CNN on the CIFAR-10 dataset.
Local Training Code
The following is a breakdown of the PyTorch code serving as the example DL process:
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from vgg import VGG
This section imports sys and the requisite PyTorch libraries for future use. In addition, a predefined VGG model is imported from the model definition file.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)
This section loads in the CIFAR-10 dataset (downloading it if necessary) and applies the transforms to each image to help augment the dataset for robust training.
device = 'cuda'

num_epochs = 200
lr = 0.001
momentum = 0.9

model = VGG('VGG16').to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr,
                      momentum=momentum, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
This section sets the device to perform training on (GPU in this case) and fixes some training-specific parameters. It then creates the initial model object and the PyTorch objects used to optimize the model parameters during the training process.
for epoch in range(num_epochs):
    print('\nEpoch: %d' % (epoch + 1))

    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Overwrite the same console line with the running accuracy
        sys.stdout.write(f"\rEpoch Accuracy: {(100*correct/total):.2f}%")
    print('\n')

    # Evaluate on the held-out set every five epochs
    if epoch % 5 == 0:
        model.eval()
        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100.*correct/total
        print(f"Accuracy on val set: {acc}%")
Finally, this section handles the actual training of the model. Training on the train dataset occurs every epoch, and validation set accuracy is computed every five epochs.
In summary, this code trains the VGG-16 model on the CIFAR-10 dataset for 200 epochs.
Integration with BasicClient
In STADLE, a client acts as the interface between the model training done locally and the FL process managed by STADLE's other components. BasicClient is an implementation of the STADLE client, intended for cases where maximal control of the FL process or minimal integration is desired.
The process of integrating with STADLE using the BasicClient can be broken down into four steps:
Create and properly configure the BasicClient object
Connect the BasicClient to STADLE (via an aggregator)
Modify the training loop to send the model to STADLE after each period of local training, then wait for the aggregated model and resume local training from it
Disconnect from STADLE when training is complete
The CIFAR-10 example will be used to show how these steps can be implemented.
Step 1: Create/Configure BasicClient
First, BasicClient has to be imported from the stadle library; this is done with
from stadle import BasicClient
The BasicClient object can then be created. The configuration information of the BasicClient can be set by passing a config file path through the constructor. Refer to Config File Documentation for details on the config file parameters.
client_config_path = r"/path/to/config/file.json"
stadle_client = BasicClient(config_file=client_config_path)
Alternatively, specific config parameter values can be set directly with the BasicClient constructor. Information on the config file and these parameters, as well as all subsequent function calls, can be found at Client API Documentation.
Step 2: Connect BasicClient to STADLE
The connection between the BasicClient and the aggregator it is configured to connect to can then be opened with
stadle_client.connect(model)
Note that the newly initialized model (in this case, the VGG model) is passed to the client, which uses it as a container for the aggregated parameters received each round.
Step 3: Modify Training Loop
The local training code previously shown trains the VGG model for 200 epochs. In order to apply federated learning to this training process, these 200 epochs must be broken into numerous short local training periods. For this example, these local training periods will be two epochs long; thus, 100 aggregation rounds of two epochs each will be run.
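The split into rounds is simple arithmetic, and a short helper makes the epoch boundaries explicit. This is only a sketch of the schedule described above; round_schedule is a hypothetical helper and is not part of the stadle library.

```python
def round_schedule(num_epochs: int, epochs_per_round: int):
    """Return (first_epoch, last_epoch) pairs, one per local training period."""
    if epochs_per_round <= 0 or num_epochs % epochs_per_round != 0:
        raise ValueError("num_epochs must be a positive multiple of epochs_per_round")
    return [(start, start + epochs_per_round - 1)
            for start in range(0, num_epochs, epochs_per_round)]


# 200 epochs in periods of two epochs gives 100 local training periods.
schedule = round_schedule(200, 2)
```

Using zero-based epoch indices, the first period covers epochs 0 and 1 and the last covers epochs 198 and 199.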
After one such training period, all of the CIFAR-10 “agents” connected to an aggregator send their locally-trained models to the aggregator, waiting to receive the aggregated model before starting the next training period with the received model. The following shows an example of how this can be done within the main training loop of the local training code:
for epoch in range(num_epochs):
    print('\nEpoch: %d' % (epoch + 1))

    # --- Addition for STADLE integration ---
    # Every two epochs, exchange models with the aggregator
    if epoch % 2 == 0:
        # Don't send a model at the beginning of training
        if epoch != 0:
            stadle_client.send_trained_model(model)

        # Block until the aggregated model is available, then resume from it
        sg_model_dict = stadle_client.wait_for_sg_model()
        model.load_state_dict(sg_model_dict)

    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        sys.stdout.write(f"\rEpoch Accuracy: {(100*correct/total):.2f}%")
    print('\n')
Step 4: Disconnect from STADLE
Finally, the BasicClient can be disconnected with
stadle_client.disconnect()
once all training rounds have completed or some other condition has been met.
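The control flow of the four steps can be dry-run without a network or a GPU by swapping in a stand-in client. StubClient below is hypothetical: it mirrors only the method names used in this section (connect, send_trained_model, wait_for_sg_model, disconnect) and records calls instead of talking to an aggregator. The try/finally wrapper is a design suggestion, not something the STADLE documentation prescribes.

```python
class StubClient:
    """Hypothetical stand-in for BasicClient that only counts calls."""

    def __init__(self):
        self.connected = False
        self.sent = 0
        self.received = 0

    def connect(self, model):
        self.connected = True

    def send_trained_model(self, model):
        self.sent += 1

    def wait_for_sg_model(self):
        self.received += 1
        return {}  # placeholder for an aggregated state dict

    def disconnect(self):
        self.connected = False


def run_training(client, num_epochs=200, epochs_per_round=2):
    model = {}  # placeholder for the real model
    client.connect(model)
    try:
        for epoch in range(num_epochs):
            if epoch % epochs_per_round == 0:
                if epoch != 0:
                    client.send_trained_model(model)
                model = client.wait_for_sg_model()
            # One epoch of local training would run here.
    finally:
        # Disconnect even if training raises an exception.
        client.disconnect()
    return client


client = run_training(StubClient())
```

With this loop structure the client waits for an aggregated model 100 times but sends only 99 models, because the model trained in the final two epochs is never sent. If that last update matters, one extra send after the loop is an option.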
Running Client-Side STADLE Components
After starting the required server-side STADLE components (i.e., the persistence server and one or more aggregators), the final step to initialize the FL process is to upload the base model to the STADLE persistence server. This base model is used internally to convert between specific ML frameworks and STADLE’s framework-agnostic model representation.
Uploading the Base Model via CLI
To upload the base model, use the stadle upload-model command along with a configuration file that includes the base model specification.
stadle upload-model --config_path <path_to_config_file>
The configuration file must include a base_model section structured as follows:
"base_model": {
    "model_name": "CIFAR-10 VGG Model",
    "model_fn_src": "vgg",
    "model_fn": "VGG",
    "model_fn_args": {
        "vgg_name": "VGG16"
    },
    "model_format": "PyTorch"
}
Here is what each field represents:
"model_name": Name to be associated with the base model object.
"model_fn_src": Module containing the function that returns the model object. This can refer to a local file or an installed module.
"model_fn": Name of the model class or function to instantiate.
"model_fn_args": Arguments passed as kwargs to model_fn (optional).
"model_format": Framework format of the model returned by model_fn (PyTorch in this example).
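The field descriptions suggest a simple pattern: import a module, look up a callable, and call it with keyword arguments. The sketch below illustrates that pattern with the standard library only. It is an assumption about how such a block could be interpreted, not STADLE's actual loader. The helper names and the set of required keys are hypothetical, and the standard-library fractions.Fraction class stands in for a real model class so that the example runs anywhere.

```python
import importlib
import json

# Assumed to be mandatory; only model_fn_args is documented as optional.
REQUIRED_KEYS = {"model_name", "model_fn_src", "model_fn", "model_format"}


def load_base_model_spec(config_text: str) -> dict:
    """Parse a config document and return its validated base_model block."""
    spec = json.loads(config_text)["base_model"]
    missing = REQUIRED_KEYS - spec.keys()
    if missing:
        raise KeyError(f"base_model is missing: {sorted(missing)}")
    spec.setdefault("model_fn_args", {})
    return spec


def build_model(spec: dict):
    """Import model_fn_src, look up model_fn, and call it with model_fn_args."""
    module = importlib.import_module(spec["model_fn_src"])
    factory = getattr(module, spec["model_fn"])
    return factory(**spec["model_fn_args"])


# Stand-in config: fractions.Fraction plays the role of the model class.
demo_config = """
{
    "base_model": {
        "model_name": "Demo",
        "model_fn_src": "fractions",
        "model_fn": "Fraction",
        "model_fn_args": {"numerator": 3, "denominator": 4},
        "model_format": "PyTorch"
    }
}
"""

spec = load_base_model_spec(demo_config)
demo_object = build_model(spec)
```

For the CIFAR-10 configuration shown earlier, the same pattern would amount to importing vgg and calling VGG(vgg_name="VGG16").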
Assuming the file vgg.py contains a model class VGG, and the configuration file includes the above base_model block, you would upload the model like so:
stadle upload-model --config_path ./config/vgg_config.json
Once the upload is successful, you can verify the base model on your STADLE project dashboard.
Running Agents
After uploading the base model, you can begin running the agent process from the previous section to connect to the aggregator and participate in federated learning.
Each agent will communicate with the aggregator, send local model updates, and receive aggregated models in return.
Execution Order Summary
Start the persistence server
Start one or more aggregators
Run stadle upload-model with a configuration file containing the base_model block
Run agent(s) to participate in federated training