Usage
This section covers the steps required to integrate STADLE with a basic deep learning training process. Please refer to Quickstart to set up the client environment to connect to STADLE. Also, please download the sample code from here; the STADLE libraries are already integrated into it.
STADLE Aggregator Functionalities
The STADLE aggregators can be configured through the stadle.ai dashboard as explained in the Quickstart and its User Guide.
Creating Project
Once you have created an account, the first step is to create a new project on the Overview page. Each project is assigned one AI model, such as VGG16. To federate several AI models, create a separate project for each one, because the model architecture must be identical across all agents connected to an aggregator.
Note
With your free account, you will be able to create only one project.
Initiating Aggregator
Once you have created a project, you can initiate aggregator(s) on the Overview page. To set up decentralized aggregation with multiple aggregators, initiate multiple aggregator instances within one project so that semi-global models are created.
Note
With your free account, you will be able to initiate only one aggregator.
Downloading Models
You can download the most recent global ML models, as well as the most recent local models and the best-performance models, from the STADLE dashboard.
Completing Current Round
This functionality wraps up the current round of aggregation. An aggregator normally needs to collect a certain number of local ML models before it proceeds with aggregation. However, you can force aggregation to happen even if not enough local models have been collected from agents by executing the “Complete Current Round” functionality.
Aggregation Threshold
This specifies the fraction of the active agents connected to the aggregator that must upload local models before aggregation proceeds. For example, if the “Aggregation Threshold” is 0.7, local models must be collected from 70% of the active agents.
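The threshold arithmetic, together with the forced completion described under Completing Current Round, can be sketched in plain Python. This is a minimal illustration, not STADLE code: the function names are hypothetical, and rounding the required count up to a whole model (`math.ceil`) is an assumption about how fractional counts are handled.

```python
import math

def models_needed(active_agents: int, threshold: float) -> int:
    """Local models required before aggregation (ceiling rounding is an assumption)."""
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    # Subtract a tiny epsilon so float error (e.g. 10 * 0.7) never rounds up spuriously
    return max(0, math.ceil(active_agents * threshold - 1e-9))

def ready_to_aggregate(collected: int, active_agents: int, threshold: float,
                       force: bool = False) -> bool:
    """True once enough local models have arrived, or when the round is forced
    to complete (the dashboard's "Complete Current Round")."""
    return force or collected >= models_needed(active_agents, threshold)

print(models_needed(10, 0.7))                       # 7
print(ready_to_aggregate(6, 10, 0.7))               # False
print(ready_to_aggregate(6, 10, 0.7, force=True))   # True
```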
Agent Timeout
This setting disconnects active agents if the aggregator has not heard from them within the number of seconds specified by the user. For example, if the timeout value is 30 and an agent is stopped or disconnected from the network for 30 seconds, the aggregator sets this agent’s status to TIMEOUT. An agent with TIMEOUT status is not counted as an active agent and is excluded from the aggregation process until it connects to the aggregator again. If the timeout value is 0, the agent timeout functionality is disabled.
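A minimal sketch of the timeout rule, assuming the aggregator records the time it last heard from each agent. `agent_status` and `active_agents` are hypothetical helpers, and treating an agent as timed out once exactly `timeout_s` seconds have elapsed (`>=`) is an assumption. When a TIMEOUT agent reconnects, updating its last-heard time makes it count as active again.

```python
ACTIVE, TIMEOUT = "ACTIVE", "TIMEOUT"

def agent_status(last_heard: float, now: float, timeout_s: float) -> str:
    """Status of one agent from the time the aggregator last heard from it.
    A timeout of 0 disables the check; using >= at the boundary is an assumption."""
    if timeout_s == 0:
        return ACTIVE
    return TIMEOUT if now - last_heard >= timeout_s else ACTIVE

def active_agents(last_heard_by_agent: dict, now: float, timeout_s: float) -> list:
    """Agents that still count toward aggregation."""
    return sorted(name for name, t in last_heard_by_agent.items()
                  if agent_status(t, now, timeout_s) == ACTIVE)

heard = {"agent_a": 95.0, "agent_b": 65.0}   # last-heard times, in seconds
print(active_agents(heard, now=100.0, timeout_s=30))  # ['agent_a']
print(active_agents(heard, now=100.0, timeout_s=0))   # ['agent_a', 'agent_b']
```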
Aggregation Method Selection
While FedAvg is the default aggregation method and works well for many applications, you can select the aggregation method best suited to your ML application. The currently supported methods are FedAvg, Geometric Median, Coordinate-Wise Median, Krum, and Krum Averaging.
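To illustrate how these methods differ, the sketch below implements them on models flattened into plain lists of floats (an assumption made for readability; STADLE's own implementations and parameter layout may differ). Krum needs an assumed bound `f` on the number of faulty agents, and reading "Krum Averaging" as Multi-Krum (averaging the `m` best-scoring models) is an assumption. Geometric Median is approximated with Weiszfeld iterations.

```python
import math
from statistics import median

def fedavg(models, weights=None):
    """Weighted mean of each parameter (weights are often local sample counts)."""
    weights = weights if weights is not None else [1.0] * len(models)
    total = sum(weights)
    return [sum(w * m[i] for m, w in zip(models, weights)) / total
            for i in range(len(models[0]))]

def coordinate_wise_median(models):
    """Median of each parameter taken independently across the models."""
    return [median(m[i] for m in models) for i in range(len(models[0]))]

def _sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def krum_scores(models, f):
    """Krum score: sum of squared distances to the n - f - 2 nearest other models,
    where f is the assumed maximum number of faulty agents."""
    k = len(models) - f - 2
    if k < 1:
        raise ValueError("Krum needs more than f + 2 models")
    scores = []
    for i, mi in enumerate(models):
        dists = sorted(_sq_dist(mi, mj) for j, mj in enumerate(models) if j != i)
        scores.append(sum(dists[:k]))
    return scores

def krum(models, f):
    """Return the single model with the lowest Krum score."""
    scores = krum_scores(models, f)
    return models[scores.index(min(scores))]

def krum_average(models, f, m):
    """Multi-Krum style (assumed meaning): FedAvg over the m lowest-scoring models."""
    scores = krum_scores(models, f)
    best = sorted(range(len(models)), key=scores.__getitem__)[:m]
    return fedavg([models[i] for i in best])

def geometric_median(models, iters=100, eps=1e-8):
    """Weiszfeld iterations toward the point minimising summed Euclidean distance."""
    z = fedavg(models)
    for _ in range(iters):
        w = [1.0 / max(math.sqrt(_sq_dist(z, m)), eps) for m in models]
        z = fedavg(models, w)
    return z

# Four similar agents plus one outlier (e.g. a faulty or malicious agent)
models = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [2.0, 2.0], [100.0, 100.0]]
print(fedavg(models))                  # [20.6, 20.8]
print(coordinate_wise_median(models))  # [1.0, 2.0]
print(krum(models, f=1))               # [0.0, 0.0]
print(krum_average(models, f=1, m=2))  # [0.5, 0.0]
```

With one outlying agent, FedAvg is pulled far from the cluster, while the median-based and Krum-based methods stay near the four similar models; this robustness is the usual reason to choose them over FedAvg.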
Synthesize Semi-Global Models
STADLE supports a decentralized aggregator architecture in which multiple aggregators synthesize an additional layer of global models, called Semi-Global Models (SG Models). Semi-Global Models let STADLE create the global model in a decentralized manner, so that federated learning can be scaled horizontally.
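One way to picture the extra layer is two-level weighted averaging: each aggregator averages its own agents' models, and the semi-global model averages the aggregator-level results. The sketch below assumes FedAvg-style weighting by local sample counts at both levels; how STADLE actually synthesizes SG models is not specified here, and the function names are hypothetical.

```python
def weighted_average(models, weights):
    """Per-parameter weighted mean of equally shaped flat parameter lists."""
    total = sum(weights)
    return [sum(w * m[i] for m, w in zip(models, weights)) / total
            for i in range(len(models[0]))]

def aggregator_round(local_models, sample_counts):
    """First layer: one aggregator combines its own agents' models (FedAvg).
    Returns the cluster model and the number of samples it represents."""
    return weighted_average(local_models, sample_counts), sum(sample_counts)

def semi_global_model(cluster_results):
    """Second layer: combine the aggregator-level models, weighted by samples."""
    models = [model for model, _ in cluster_results]
    counts = [count for _, count in cluster_results]
    return weighted_average(models, counts)

# Two aggregators, each serving its own agents
agg_a = aggregator_round([[1.0], [3.0]], sample_counts=[1, 1])   # ([2.0], 2)
agg_b = aggregator_round([[6.0]], sample_counts=[2])             # ([6.0], 2)
print(semi_global_model([agg_a, agg_b]))                          # [4.0]

# Same answer as one flat FedAvg over every agent: (1*1 + 3*1 + 6*2) / 4
print(weighted_average([[1.0], [3.0], [6.0]], [1, 1, 2]))        # [4.0]
```

With sample-count weighting, the two-level result matches a single FedAvg over all agents, which is what allows the aggregation work to be split across aggregators.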
Aggregation Management
On the Aggregation Management page, you will be able to check the Current Round, the Maximum Number of Connectable Active Agents, the Number of Active Agents Participating, the Number of Local Models Needed for Aggregation, and the Number of Local Models Collected.
Performance Tracking
The performance of the uploaded local ML models for each aggregation round can be tracked on the Dashboard as well as on the Performance Tracking page, where you can monitor the learning process for each metric of your ML models.
Stopping & Restarting aggregators
You can stop and restart aggregators on the Config Info & Settings page. After an aggregator is successfully stopped or restarted, its status becomes “INACTIVE” or “ACTIVE”, respectively.
Client-side STADLE Integration
This section will cover the process of integrating STADLE with existing PyTorch code used to train a CNN on the CIFAR-10 dataset.
Local Training Code
The following is a breakdown of the PyTorch code serving as the example DL process:
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from vgg import VGG
This section imports sys and the requisite PyTorch libraries. In addition, a predefined VGG model is imported from the model definition file (vgg.py).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)
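This section loads the CIFAR-10 training and test sets (downloading them if necessary). Each training image is randomly cropped (after 4-pixel padding) and randomly flipped horizontally to augment the dataset for more robust training; both sets are converted to tensors and normalized with per-channel mean and standard deviation values.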
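For reference, `Normalize(mean, std)` computes `(value - mean) / std` for each channel after `ToTensor()` has scaled 8-bit pixel values into [0, 1]. The sketch below reproduces that arithmetic for a single pixel in plain Python; the helper name is hypothetical, and torchvision performs the same operation on whole tensors.

```python
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

def to_tensor_and_normalize(rgb_pixel):
    """Mirror ToTensor() + Normalize() for a single 8-bit RGB pixel:
    scale each channel to [0, 1], then apply (value - mean) / std."""
    return tuple((v / 255.0 - m) / s
                 for v, m, s in zip(rgb_pixel, CIFAR10_MEAN, CIFAR10_STD))

white = to_tensor_and_normalize((255, 255, 255))
print([round(c, 3) for c in white])   # [2.514, 2.597, 2.754]
```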
device = 'cuda'

num_epochs = 200
lr = 0.001
momentum = 0.9

model = VGG('VGG16').to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr,
                      momentum=momentum, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
This section sets the device to train on (a GPU in this case) and the training hyperparameters. It then creates the initial model object, along with the loss function, optimizer, and learning-rate scheduler used to optimize the model parameters during training.
for epoch in range(num_epochs):
    print('\nEpoch: %d' % (epoch + 1))

    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Overwrite the same console line with the running epoch accuracy
        sys.stdout.write(f"\rEpoch Accuracy: {(100*correct/total):.2f}%")
    print('\n')

    # Advance the cosine learning-rate schedule once per epoch
    scheduler.step()

    # Evaluate every five epochs (epochs 1, 6, 11, ... in the printed numbering)
    if epoch % 5 == 0:
        model.eval()
        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100.*correct/total
        print(f"Accuracy on val set: {acc}%")
Finally, this section handles the actual training of the model. Training on the train dataset occurs every epoch, and validation set accuracy is computed every five epochs.
In summary, this code trains the VGG-16 model on the CIFAR-10 dataset for 200 epochs.
Integration with BasicClient
In STADLE, a client acts as the interface between the model training done locally and the FL process managed by STADLE’s other components. BasicClient is an implementation of the STADLE client intended for cases where maximal control of the FL process or minimal integration is desired.
The process of integrating with STADLE using the BasicClient can be broken down into four steps:
Create and properly configure the BasicClient object
Connect the BasicClient to STADLE (via an aggregator)
Modify the training loop to send a model to STADLE after some period of local training and to wait to receive the aggregated model as a checkpoint to resume local training.
Disconnect from STADLE when training is complete
The CIFAR-10 example will be used to show how these steps can be implemented.
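Before looking at each step, the overall lifecycle can be sketched with a stand-in client. `StubClient` and `run_rounds` are hypothetical: the stub only records calls and borrows the method names used in this section, and the loop assumes each round starts by waiting for the aggregated model and ends by sending the locally trained one. Wrapping the rounds in `try`/`finally` is a design choice that guarantees the disconnect in Step 4 happens even if local training raises an exception.

```python
class StubClient:
    """Hypothetical stand-in that only records calls. It borrows the method names
    used in this section but is NOT the STADLE BasicClient."""

    def __init__(self):
        self.calls = []

    def connect(self, model):
        self.calls.append("connect")

    def wait_for_sg_model(self):
        self.calls.append("wait")
        return {}  # placeholder for the aggregated model

    def send_trained_model(self, model):
        self.calls.append("send")

    def disconnect(self):
        self.calls.append("disconnect")

def run_rounds(client, model, num_rounds, train_period):
    client.connect(model)                       # Step 2: connect
    try:
        for _ in range(num_rounds):             # Step 3: one FL round per iteration
            aggregated = client.wait_for_sg_model()
            model = train_period(model, aggregated)
            client.send_trained_model(model)
    finally:
        client.disconnect()                     # Step 4: runs even if training fails
    return model

client = StubClient()                           # Step 1 (real code: BasicClient(config_file=...))
run_rounds(client, model={}, num_rounds=2, train_period=lambda m, agg: m)
print(client.calls)   # ['connect', 'wait', 'send', 'wait', 'send', 'disconnect']

failing = StubClient()
try:
    run_rounds(failing, model={}, num_rounds=2, train_period=lambda m, agg: 1 / 0)
except ZeroDivisionError:
    pass
print(failing.calls)  # ['connect', 'wait', 'disconnect']
```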
Step 1: Create/Configure BasicClient
First, BasicClient has to be imported from the stadle library; this is done with
from stadle import BasicClient
The BasicClient object can then be created. The configuration information of the BasicClient can be set by passing a config file path through the constructor. Refer to Config File Documentation for details on the config file parameters.
client_config_path = r"/path/to/config/file.json"
stadle_client = BasicClient(config_file=client_config_path)
Alternatively, specific config parameter values can be set directly with the BasicClient constructor. Information on the config file and these parameters, as well as all subsequent function calls, can be found at Client API Documentation.
Step 2: Connect BasicClient to STADLE
The BasicClient can then open a connection to the aggregator specified in its configuration with
stadle_client.connect(model)
Note that we pass the newly initialized model (in this case, the VGG model) to the client; the client uses it as a container for the aggregated parameters received each round.
Step 3: Modify Training Loop
The local training code previously shown trains the VGG model for 200 epochs. In order to apply federated learning to this training process, these 200 epochs must be broken into numerous short local training periods. For this example, these local training periods will be two epochs long; thus, 100 aggregation rounds of two epochs each will be run.
After one such training period, all of the CIFAR-10 “agents” connected to an aggregator send their locally-trained models to the aggregator, waiting to receive the aggregated model before starting the next training period with the received model. The following shows an example of how this can be done within the main training loop of the local training code:
for epoch in range(num_epochs):
    print('\nEpoch: %d' % (epoch + 1))

    # --- Addition for STADLE integration ---
    # Every second epoch marks the boundary between local training periods
    if epoch % 2 == 0:
        # Don't send a model at the beginning of training
        if epoch != 0:
            stadle_client.send_trained_model(model)

        # Block until the aggregated model arrives, then resume training from it
        sg_model_dict = stadle_client.wait_for_sg_model()

        model.load_state_dict(sg_model_dict)

    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        sys.stdout.write(f"\rEpoch Accuracy: {(100*correct/total):.2f}%")
    print('\n')

    scheduler.step()
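The round structure of the modified loop can be checked by listing the STADLE calls it makes. The following sketch replays only the `if` logic of that loop, with no STADLE or PyTorch involved; `stadle_schedule` is a hypothetical helper.

```python
def stadle_schedule(num_epochs: int, period: int):
    """List the (epoch, action) events produced by the every-`period`-epochs logic."""
    events = []
    for epoch in range(num_epochs):
        if epoch % period == 0:
            if epoch != 0:
                events.append((epoch, "send"))
            events.append((epoch, "wait"))
    return events

events = stadle_schedule(num_epochs=200, period=2)
waits = sum(1 for _, action in events if action == "wait")
sends = sum(1 for _, action in events if action == "send")
print(waits, sends)   # 100 99
print(events[-2:])    # [(198, 'send'), (198, 'wait')]
```

For 200 epochs this gives 100 waits and 99 sends: nothing is sent at epoch 0, and the model trained during the final two epochs (198 and 199) is never uploaded. If that last local model should also reach the aggregator, one option is an additional `stadle_client.send_trained_model(model)` after the loop, before disconnecting.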
Step 4: Disconnect from STADLE
Finally, the BasicClient can be disconnected with
stadle_client.disconnect()
once all training rounds have completed or some other condition has been met.
Integration with IntegratedClient
Using the IntegratedClient passes management of the local training process to STADLE, whereas with the BasicClient the local training loop stays under your control. As a result, integrating with the IntegratedClient is slightly more involved.
This process can be broken down into four steps:
Create and properly configure the IntegratedClient object
Split the local training process into training, cross-validation, and test functions, and pass these functions to the IntegratedClient
Construct a termination function to determine when to stop the FL process
Connect the IntegratedClient to STADLE and start the entire FL process
Similarly to the BasicClient, the CIFAR-10 example will be used to show how these steps can be implemented.
Step 1: Create/Configure IntegratedClient
IntegratedClient is imported from the stadle library; this is done with
from stadle import IntegratedClient
The IntegratedClient object can then be created and configured in the same way as the BasicClient:
client_config_path = r"/path/to/config/file.json"
stadle_client = IntegratedClient(config_file=client_config_path)
Alternatively, specific config parameter values can be set directly with the IntegratedClient constructor. Information on the config file and these parameters, as well as all subsequent function calls, can be found at Client API Documentation.
Step 2: Construct Local Training Functions
When STADLE manages the local training part of the FL process, it works with abstracted versions of the training, cross-validation, and test functions. As a result, any specific implementations of these functions must match these abstractions in format. The following are template implementations of the functions in question:
Train Function:
def train(model, data, **kwargs):
    # Use data to locally train model
    # kwargs used to pass general parameters to function

    return locally_trained_model, average_training_loss
Cross-Validation Function:
def cross_validate(model, data, **kwargs):
    # Use data to compute accuracy or other performance metric (validation set)
    # kwargs used to pass general parameters to function

    return acc, ave_loss
Test Function:
def test(model, data, **kwargs):
    # Use data to compute accuracy or other performance metric (test set)
    # kwargs used to pass general parameters to function

    return acc, ave_loss
The IntegratedClient will go through the following steps to fulfill the agent-side role in FL:
Check termination function output, continue if True
Receive previous round aggregated model from aggregator
Run cross_validate function on aggregated model
Run train function to train model locally
Run cross_validate function on locally-trained model
Send locally-trained model to aggregator
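The six steps can be simulated in plain Python to show when each user-supplied function is called. `run_integrated_agent` is a hypothetical re-creation, not the IntegratedClient's internals: `receive_model` and `send_model` stand in for the aggregator exchange, and the termination function is assumed to return True to keep running, as the `judge_termination` example in Step 3 does.

```python
def run_integrated_agent(judge_termination, train, cross_validate,
                         receive_model, send_model, train_data, val_data,
                         max_rounds=1000):
    """Hypothetical re-creation of the agent-side steps (numbered as in the list)."""
    log = []
    for round_idx in range(max_rounds):
        if not judge_termination(round_idx=round_idx):          # 1. check termination
            break
        model = receive_model()                                # 2. aggregated model
        log.append(("aggregated", round_idx, cross_validate(model, val_data)))  # 3.
        model, _ = train(model, train_data)                    # 4. local training
        log.append(("local", round_idx, cross_validate(model, val_data)))       # 5.
        send_model(model)                                      # 6. upload
    return log

# Toy run: the "model" is an integer and training adds 1 to it
server = {"model": 0}
sent = []

def send_model(model):
    sent.append(model)
    server["model"] = model   # a single-agent "aggregator" just keeps the latest model

log = run_integrated_agent(
    judge_termination=lambda round_idx: round_idx < 2,
    train=lambda model, data: (model + 1, 0.0),
    cross_validate=lambda model, data: model,
    receive_model=lambda: server["model"],
    send_model=send_model,
    train_data=None, val_data=None)
print(log)   # [('aggregated', 0, 0), ('local', 0, 1), ('aggregated', 1, 1), ('local', 1, 2)]
print(sent)  # [1, 2]
```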
The CIFAR-10 local training example code can then be segmented into these functions in the following way:
Train Function (CIFAR-10):
def train(model, data, **kwargs):
    lr = float(kwargs.get("lr")) if kwargs.get("lr") else 0.001
    momentum = float(kwargs.get("momentum")) if kwargs.get("momentum") else 0.9
    epochs = int(kwargs.get("epochs")) if kwargs.get("epochs") else 2
    device = kwargs.get("device") if kwargs.get("device") else 'cpu'

    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr,
                          momentum=momentum, weight_decay=5e-4)

    batch_losses = []

    for epoch in range(epochs):  # loop over the dataset multiple times

        print('\nEpoch: %d' % (epoch + 1))

        model.train()
        correct = 0
        total = 0
        # Iterate over the loader passed in as `data` (not the global trainloader)
        for batch_idx, (inputs, targets) in enumerate(data):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            # Store each batch's loss (not a running sum) so the average below is correct
            batch_losses.append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print(f"Epoch Accuracy: {(100*correct/total):.2f}%")

    ave_loss = sum(batch_losses) / len(batch_losses)

    # Return the locally trained model on the CPU
    model = model.to('cpu')

    return model, ave_loss
Cross-Validation Function (CIFAR-10):
def cross_validate(test_model, data, **kwargs):
    device = kwargs.get("device") if kwargs.get("device") else 'cpu'
    # Assumes `data` is a DataLoader over torchvision's CIFAR10, which exposes class names
    classes = data.dataset.classes

    test_model = test_model.to(device)
    test_model.eval()
    criterion = nn.CrossEntropyLoss()

    correct = 0
    total = 0
    loss_sum = 0.0
    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    with torch.no_grad():
        for inputs, targets in data:
            inputs, targets = inputs.to(device), targets.to(device)
            # calculate outputs by running images through the network
            outputs = test_model(inputs)
            loss_sum += criterion(outputs, targets).item() * targets.size(0)
            # the class with the highest score is chosen as the prediction
            _, predictions = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predictions == targets).sum().item()
            # collect the correct predictions for each class
            for target, prediction in zip(targets, predictions):
                classname = classes[target.item()]
                if prediction == target:
                    correct_pred[classname] += 1
                total_pred[classname] += 1

    overall_accuracy = 100 * correct / total
    ave_loss = loss_sum / total
    print('Accuracy of the network on the %d images: %.2f %%' % (total, overall_accuracy))

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print("Accuracy for class {:5s} is: {:.1f} %".format(classname, accuracy))

    return overall_accuracy, ave_loss
We can use the same implementation for the test function in this case, simply changing the dataset passed to the function.
Test Function (CIFAR-10):
def test(test_model, data, **kwargs):
    # Identical evaluation logic to cross_validate; only the dataset passed in differs
    return cross_validate(test_model, data, **kwargs)
Step 3: Construct Termination Function
The termination function is a user-defined function that controls when an agent exits an FL process. The agent runs the function at the beginning of each round; following the example below, it returns True to keep training and False once the agent should exit.
One simple termination function is to return True after a certain number of rounds has passed; the following is an implementation of such a function:
def judge_termination(**kwargs) -> bool:
    """
    Decide if it finishes training process and exits from FL platform
    :param training_count: int - the number of training done
    :param sg_arrival_count: int - the number of times it received SG models
    :return: bool - True if it continues the training loop; False if it stops
    """

    keep_running = True
    client = kwargs.get('client')
    current_fl_round = client.federated_training_round

    if current_fl_round >= int(kwargs.get("round_to_exit")):
        keep_running = False
        client.stop_model_exchange_routine()
    return keep_running
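Because the termination function only reads `client.federated_training_round` and may call `client.stop_model_exchange_routine()`, its logic can be checked offline with a stand-in object. `FakeClient` below is hypothetical and exposes only those two members; the function body repeats the example above.

```python
def judge_termination(**kwargs) -> bool:
    """Same logic as the example above: True keeps the agent running."""
    keep_running = True
    client = kwargs.get('client')
    if client.federated_training_round >= int(kwargs.get("round_to_exit")):
        keep_running = False
        client.stop_model_exchange_routine()
    return keep_running

class FakeClient:
    """Hypothetical stand-in exposing only the two members judge_termination uses."""
    def __init__(self, federated_training_round):
        self.federated_training_round = federated_training_round
        self.stopped = False

    def stop_model_exchange_routine(self):
        self.stopped = True

for rnd in (19, 20):
    fake = FakeClient(rnd)
    print(rnd, judge_termination(client=fake, round_to_exit=20), fake.stopped)
# 19 True False
# 20 False True
```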
Step 4: Setup, Connect IntegratedClient to STADLE
The following is example code to set up the IntegratedClient with the previously defined functions and start the FL process:
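import argparse

parser = argparse.ArgumentParser(description='STADLE CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lt_epochs', default=3, type=int, help='local training epochs per round')
parser.add_argument('--agent_name', default='default_agent', type=str,
                    help='agent name forwarded to the training function')

args = parser.parse_args()

device = 'cuda'

model = VGG('VGG16')
Read in the learning rate, the number of local training epochs, and the agent name from the command-line arguments, set the training device, and define the model to be trained.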
trainset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)
Use the same CIFAR-10 datasets as the local training example.
stadle_client.set_termination_function(judge_termination, round_to_exit=20, client=stadle_client)
stadle_client.set_training_function(train, trainloader, lr=args.lr, epochs=args.lt_epochs, device=device, agent_name=args.agent_name)
stadle_client.set_cross_validation_function(cross_validate, testloader, device=device)
stadle_client.set_testing_function(test, testloader)
Pass the functions to the IntegratedClient for use in its internal training loop.
stadle_client.set_bm_obj(model)
stadle_client.start()
Set the container model for the client, then start the agent FL process.
Running Client-Side STADLE Components
After starting the requisite server-side STADLE components, one final step must be run to fully initialize an FL process with STADLE and prepare for agent connections. The component responsible for this is called the admin agent. Its role here is to send the model structure and related information to the persistence server, for use in converting between specific model frameworks and the framework-agnostic model representation used by STADLE. The following is example admin agent code for the CIFAR-10 example:
from stadle import AdminAgent
from stadle import BaseModelConvFormat
from stadle.lib.entity.model import BaseModel
from stadle.lib.util import admin_arg_parser

from vgg import VGG
This section imports the required objects from STADLE, as well as a function for reading command line arguments and the VGG model. The BaseModel object acts as a container for information on the model being trained with STADLE, and is passed to the AdminAgent to be sent to the persistence server.
base_model = BaseModel("PyTorch-CIFAR10-Model", VGG('VGG16'), BaseModelConvFormat.pytorch_format)
The specific BaseModel object is then created with the VGG16 model structure and information.
args = admin_arg_parser()
admin_agent = AdminAgent(config_file=args.config_path, simulation_flag=args.simulation,
                         aggregator_ip_address=args.ip_address, reg_socket=args.reg_port,
                         exch_socket=args.exch_port, model_path=args.model_path, base_model=base_model,
                         agent_running=args.agent_running)

admin_agent.preload()
admin_agent.initialize()
The command-line arguments are parsed and used to create the AdminAgent object, along with the base model. The preload function prepares the base model to be sent (converting it to the framework-agnostic representation internally), and the initialize function sends the base model information, which in turn prepares the aggregators to connect to agents.
After the admin agent has run, the client-side agent code can be run. In summary, components should be started in the following order:
Start persistence server
Start aggregator(s)
Run admin agent (only once)
Run agent(s) - client-side code