Usage

This section covers the steps required to integrate STADLE with a basic deep learning training process. Please refer to Quickstart to set up the client environment for connecting to STADLE. Also, please download the sample code from here, in which the STADLE libraries are already integrated.

STADLE Aggregator Functionalities

The STADLE aggregators can be configured through the stadle.ai dashboard as explained in the Quickstart and its User Guide.

Creating Project

Once you have created an account, the first step is to create a new project on the Overview page. Each project can be assigned one AI model, such as VGG16. To federate several AI models, create a separate project for each model, because the model architecture must be consistent among all the agents connected to an aggregator.

Note

With your free account, you will be able to create only one project.

Initiating Aggregator

Once you create a project, you will be able to initiate aggregator(s) on the Overview page. If you would like to set up decentralized aggregation with multiple aggregators, you can initiate multiple aggregator instances within one project so that semi-global models will be created.

Note

With your free account, you will be able to initiate only one aggregator.

Downloading Models

You will be able to download the most recent global ML models, as well as the most recent local models and best-performance models, on the STADLE dashboard.

Completing Current Round

This functionality wraps up the current round of aggregation. An aggregator normally needs to collect a certain number of local ML models before it can proceed with aggregation. However, you can force the aggregation to happen even if not enough local models have been collected from agents by executing the “Complete Current Round” functionality.

Aggregation Threshold

This specifies the fraction of the active agents connected to the aggregator whose local models must be collected before aggregation. For example, if the “Aggregation Threshold” is 0.7, local models from 70% of the active agents are needed.
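
The threshold arithmetic can be sketched as follows. This is only an illustration: the function name is hypothetical, and rounding up to the next whole model is an assumption rather than documented STADLE behavior.

```python
import math


def models_needed(threshold: float, active_agents: int) -> int:
    """Number of local models to collect before aggregation can proceed.

    Rounding up is an assumption made for this sketch.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    # round() guards against floating-point noise (0.7 * 3 is 2.0999999999999996)
    return math.ceil(round(threshold * active_agents, 9))


print(models_needed(0.7, 10))  # 7
```

With a threshold of 0.7 and 10 active agents, 7 local models are needed; under the round-up assumption, 3 active agents would require all 3.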

Agent Timeout

This feature disconnects active agents when the aggregator has not heard from them for a user-specified number of seconds. For example, if the timeout value is 30 and an agent is stopped or disconnected from the network for 30 seconds, the aggregator sets this agent’s status to TIMEOUT. An agent whose status is TIMEOUT is not counted as an active agent and is excluded from the aggregation process until it connects to the aggregator again. If the timeout value is 0, the agent timeout functionality is disabled.
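
The timeout rule described above can be expressed as a small function. The function and parameter names are hypothetical and this is not the aggregator's code; it only restates the rule, assuming an agent times out once the silence reaches the configured number of seconds.

```python
def agent_status(last_heard: float, now: float, timeout: float) -> str:
    """Classify an agent from the time (in seconds) it was last heard from."""
    if timeout == 0:
        # A timeout value of 0 disables the timeout functionality
        return "ACTIVE"
    silent_for = now - last_heard
    return "TIMEOUT" if silent_for >= timeout else "ACTIVE"


print(agent_status(last_heard=0, now=30, timeout=30))  # TIMEOUT
print(agent_status(last_heard=0, now=10, timeout=30))  # ACTIVE
print(agent_status(last_heard=0, now=999, timeout=0))  # ACTIVE
```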

Aggregation Method Selection

FedAvg is the default aggregation method and works well for many applications, but you can choose the most suitable aggregation method for your ML application. The currently supported aggregation methods are FedAvg, Geometric Median, Coordinate-Wise Median, Krum, and Krum Averaging.
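
To give a feel for how two of these methods differ, the sketch below applies the textbook definitions of FedAvg (a sample-weighted average) and coordinate-wise median to models represented as flat lists of floats. It is not STADLE's implementation; the function names and the flat-list representation are assumptions made for illustration.

```python
from statistics import median
from typing import List, Sequence


def fed_avg(models: Sequence[Sequence[float]], num_samples: Sequence[int]) -> List[float]:
    """Average each parameter, weighting every model by its number of samples."""
    total = sum(num_samples)
    return [
        sum(value * n for value, n in zip(column, num_samples)) / total
        for column in zip(*models)  # one column per parameter
    ]


def coordinate_wise_median(models: Sequence[Sequence[float]]) -> List[float]:
    """Take the median of each parameter independently."""
    return [median(column) for column in zip(*models)]


# Two well-behaved local models and one outlier
local_models = [[1.0, 2.0], [3.0, 4.0], [100.0, -100.0]]
print(fed_avg(local_models, [1, 1, 1]))      # pulled strongly toward the outlier
print(coordinate_wise_median(local_models))  # [3.0, 2.0]
```

The median-based result stays close to the two well-behaved models, which is why robust methods can be preferable when some agents may send corrupted or adversarial updates.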

Synthesize Semi-Global Models

STADLE supports a decentralized architecture in which multiple aggregators can be set up to synthesize another layer of global models, which we call Semi-Global Models (SG Models). Semi-Global Models are STADLE’s approach to creating the global model in a decentralized manner so that you can scale federated learning horizontally.

Aggregation Management

On the Aggregation Management page, you will be able to check the Current Round, the Maximum Number of Connectable Active Agents, the Number of Active Agents Participating, the Number of Local Models Needed for Aggregation, and the Number of Local Models Collected.

Performance Tracking

The performance of the uploaded local ML models for each aggregation round can be tracked on the Dashboard as well as on the Performance Tracking page. You can monitor the learning process for each metric of your ML models there.

Stopping & Restarting Aggregators

You can stop or restart aggregators on the Config Info & Settings page. The aggregator status becomes “INACTIVE” after a successful stop and “ACTIVE” after a successful restart.

Client-side STADLE Integration

This section will cover the process of integrating STADLE with existing PyTorch code used to train a CNN on the CIFAR-10 dataset.

Local Training Code

The following is a breakdown of the PyTorch code serving as the example DL process:

import sys

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from vgg import VGG

This section imports sys and the requisite PyTorch libraries for future use. In addition, a predefined VGG model is imported from the model definition file.

# Augmentation and normalization for the training set
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Normalization only for the test set
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)

This section loads in the CIFAR-10 dataset (downloading it if necessary) and applies the transforms to each image to help augment the dataset for robust training.

device = 'cuda'

num_epochs = 200
lr = 0.001
momentum = 0.9

model = VGG('VGG16').to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr,
                      momentum=momentum, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

This section sets the device to perform training on (GPU in this case) and fixes some training-specific parameters. It then creates the initial model object and the PyTorch objects used to optimize the model parameters during the training process.
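
For intuition about what the scheduler does to the learning rate, the following sketch evaluates the closed-form cosine annealing formula with the values used above (initial rate 0.001, T_max of 200, minimum rate 0). The helper name is hypothetical, and the sketch assumes the scheduler is stepped once per epoch and nothing else modifies the learning rate.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 0.001,
                        t_max: int = 200, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 50, 100, 150, 200):
    print(epoch, f"{cosine_annealing_lr(epoch):.6f}")
```

The rate starts at 0.001, reaches half of that at epoch 100, and approaches 0 at epoch 200. Note that the schedule only advances when `scheduler.step()` is called, typically once per epoch.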

for epoch in range(num_epochs):
    print('\nEpoch: %d' % (epoch + 1))

    # Train on the full training set
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        sys.stdout.write(f"\rEpoch Accuracy: {(100*correct/total):.2f}%")
    print('\n')

    # Advance the cosine learning-rate schedule once per epoch
    scheduler.step()

    # Evaluate on the held-out set every five epochs
    if epoch % 5 == 0:
        model.eval()
        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100.*correct/total
        print(f"Accuracy on val set: {acc}%")

Finally, this section handles the actual training of the model. Training on the train dataset occurs every epoch, and validation set accuracy is computed every five epochs.

In summary, this code trains the VGG-16 model on the CIFAR-10 dataset for 200 epochs.

Integration with BasicClient

In STADLE, the purpose of a client is to act as an interface between the model training being done locally and the FL process managed by STADLE’s other components. BasicClient is an implementation of the STADLE client, intended for cases where maximal control of the FL process or minimal integration is desired.

The process of integrating with STADLE using the BasicClient can be broken down into four steps:

  1. Create and properly configure the BasicClient object

  2. Connect the BasicClient to STADLE (via an aggregator)

  3. Modify the training loop to send a model to STADLE after some period of local training and to wait to receive the aggregated model as a checkpoint to resume local training.

  4. Disconnect from STADLE when training is complete

The CIFAR-10 example will be used to show how these steps can be implemented.

Step 1: Create/Configure BasicClient

First, BasicClient has to be imported from the stadle library; this is done with

from stadle import BasicClient

The BasicClient object can then be created. The configuration information of the BasicClient can be set by passing a config file path through the constructor. Refer to Config File Documentation for details on the config file parameters.

client_config_path = r"/path/to/config/file.json"
stadle_client = BasicClient(config_file=client_config_path)

Alternatively, specific config parameter values can be set directly with the BasicClient constructor. Information on the config file and these parameters, as well as all subsequent function calls, can be found at Client API Documentation.

Step 2: Connect BasicClient to STADLE

The connection between the BasicClient and the aggregator it is configured to connect to can then be opened with

stadle_client.connect(model)

Note that we pass the recently initialized model (in this case, the VGG model) to the client for use as a container for the aggregated parameters received each round.

Step 3: Modify Training Loop

The local training code previously shown trains the VGG model for 200 epochs. In order to apply federated learning to this training process, these 200 epochs must be broken into numerous short local training periods. For this example, these local training periods will be two epochs long; thus, 100 aggregation rounds of two epochs each will be run.

After one such training period, all of the CIFAR-10 “agents” connected to an aggregator send their locally-trained models to the aggregator, waiting to receive the aggregated model before starting the next training period with the received model. The following shows an example of how this can be done within the main training loop of the local training code:

for epoch in range(num_epochs):
    print('\nEpoch: %d' % (epoch + 1))

    # Addition for STADLE integration: exchange models every two epochs
    if epoch % 2 == 0:
        # Don't send model at beginning of training
        if epoch != 0:
            stadle_client.send_trained_model(model)

        # Block until the aggregated model arrives, then resume from it
        sg_model_dict = stadle_client.wait_for_sg_model()
        model.load_state_dict(sg_model_dict)

    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        sys.stdout.write(f"\rEpoch Accuracy: {(100*correct/total):.2f}%")
    print('\n')
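
To make the timing of the exchange explicit, the sketch below replays only the control flow of the loop above and records which action happens at each epoch. It involves no STADLE calls; the function name and action labels are hypothetical.

```python
from typing import List


def exchange_schedule(num_epochs: int, epochs_per_round: int = 2) -> List[str]:
    """List the send / wait / train actions in the order the loop performs them."""
    actions = []
    for epoch in range(num_epochs):
        if epoch % epochs_per_round == 0:
            if epoch != 0:
                actions.append("send")  # send the locally trained model
            actions.append("wait")      # wait for the aggregated model
        actions.append("train")         # one epoch of local training
    return actions


print(exchange_schedule(4))
# ['wait', 'train', 'train', 'send', 'wait', 'train', 'train']
```

For 200 epochs this yields 100 waits but only 99 sends: with this control flow, the model trained in the final two epochs is not sent unless an extra send is added after the loop.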

Step 4: Disconnect from STADLE

Finally, the BasicClient can be disconnected with

stadle_client.disconnect()

once all training rounds have completed or some other condition has been met.

Integration with IntegratedClient

Using the IntegratedClient allows for the management of the local training process to be passed to STADLE, as opposed to the more hands-off approach taken by the BasicClient. As a result, the integration process to be able to use the IntegratedClient is slightly more in-depth.

This process can be broken down into four steps:

  1. Create and properly configure the IntegratedClient object

  2. Construct a training, cross-validation, and test function (segmentation of the local training process) and pass the functions to the IntegratedClient

  3. Construct a termination function to determine when to stop the FL process

  4. Connect the IntegratedClient to STADLE and start the entire FL process

Similarly to the BasicClient, the CIFAR-10 example will be used to show how these steps can be implemented.

Step 1: Create/Configure IntegratedClient

IntegratedClient is imported from the stadle library; this is done with

from stadle import IntegratedClient

The IntegratedClient object can then be created and configured like the BasicClient:

client_config_path = r"/path/to/config/file.json"
stadle_client = IntegratedClient(config_file=client_config_path)

Alternatively, specific config parameter values can be set directly with the IntegratedClient constructor. Information on the config file and these parameters, as well as all subsequent function calls, can be found at Client API Documentation.

Step 2: Construct Local Training Functions

When STADLE manages the local training part of the FL process, it works with abstracted versions of the training, cross-validation, and test functions. As a result, any specific implementations of these functions must match these abstractions in format. The following are template implementations of the functions in question:

Train Function:

def train(model, data, **kwargs):
    # Use data to locally train model
    # kwargs used to pass general parameters to function

    return locally_trained_model, average_training_loss

Cross-Validation Function:

def cross_validate(model, data, **kwargs):
    # Use data to compute accuracy or other performance metric (validation set)
    # kwargs used to pass general parameters to function

    return acc, ave_loss

Test Function:

def test(model, data, **kwargs):
    # Use data to compute accuracy or other performance metric (test set)
    # kwargs used to pass general parameters to function

    return acc, ave_loss

The IntegratedClient will go through the following steps to fulfill the agent-side role in FL:

  1. Check termination function output, continue if false

  2. Receive previous round aggregated model from aggregator

  3. Run cross_validate function on aggregated model

  4. Run train function to train model locally

  5. Run cross_validate function on locally-trained model

  6. Send locally-trained model to aggregator
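
The six steps above can be sketched as a plain loop over user-supplied callables. This is an illustration of the sequence only, not the IntegratedClient's internals; the driver function, its parameters, and the stub callables are all hypothetical.

```python
from typing import Callable, List


def run_agent_rounds(should_terminate: Callable[[], bool],
                     receive_model: Callable[[], float],
                     cross_validate: Callable[[float], None],
                     train: Callable[[float], float],
                     send_model: Callable[[float], None]) -> List[str]:
    """Run rounds until the termination check fires; return the steps taken."""
    steps = []
    while True:
        if should_terminate():            # 1. check termination output
            steps.append("terminate")
            return steps
        model = receive_model()           # 2. receive aggregated model
        steps.append("receive")
        cross_validate(model)             # 3. validate aggregated model
        steps.append("validate_aggregated")
        model = train(model)              # 4. train locally
        steps.append("train")
        cross_validate(model)             # 5. validate locally trained model
        steps.append("validate_local")
        send_model(model)                 # 6. send locally trained model
        steps.append("send")


# Stub callables: the "model" is a single float and training adds 1 to it
sent = []
steps = run_agent_rounds(
    should_terminate=lambda: len(sent) >= 2,  # stop after two rounds
    receive_model=lambda: 0.0,
    cross_validate=lambda model: None,
    train=lambda model: model + 1.0,
    send_model=sent.append,
)
print(steps)
```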

The CIFAR-10 local training example code can then be segmented into these functions in the following way:

Train Function (CIFAR-10):

def train(model, data, **kwargs):
    # Read hyperparameters from kwargs, falling back to defaults
    lr = float(kwargs.get("lr")) if kwargs.get("lr") else 0.001
    momentum = float(kwargs.get("momentum")) if kwargs.get("momentum") else 0.9
    epochs = int(kwargs.get("epochs")) if kwargs.get("epochs") else 2
    device = kwargs.get("device") if kwargs.get("device") else 'cpu'

    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr,
                          momentum=momentum, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    batch_losses = []

    for epoch in range(epochs):  # loop over the dataset multiple times

        print('\nEpoch: %d' % (epoch + 1))

        model.train()
        correct = 0
        total = 0
        # Iterate over the data loader passed in by the client
        for batch_idx, (inputs, targets) in enumerate(data):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            # Record the loss of this batch
            batch_losses.append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Average loss over all batches of all local epochs
    ave_loss = sum(batch_losses) / len(batch_losses)

    model = model.to('cpu')

    return model, ave_loss

Cross-Validation Function (CIFAR-10):

# CIFAR-10 class names, indexed by label
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

def cross_validate(test_model, data, **kwargs):
    device = kwargs.get("device") if kwargs.get("device") else 'cpu'

    test_model = test_model.to(device)
    # Switch to evaluation mode (affects dropout and batch norm layers)
    test_model.eval()

    correct = 0
    total = 0
    overall_accuracy = 0

    with torch.no_grad():
        for (inputs, targets) in data:
            inputs, targets = inputs.to(device), targets.to(device)
            # calculate outputs by running images through the network
            outputs = test_model(inputs)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    overall_accuracy = 100 * correct / total
    print('Accuracy of the network on the validation images: %d %%' % (overall_accuracy))

    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    with torch.no_grad():
        for (inputs, targets) in data:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = test_model(inputs)
            _, predictions = torch.max(outputs, 1)
            # collect the correct predictions for each class
            for target, prediction in zip(targets, predictions):
                if prediction == target:
                    correct_pred[classes[target]] += 1
                total_pred[classes[target]] += 1

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
                                                             accuracy))
    # The loss is not computed here, so 0 is returned in its place
    return overall_accuracy, 0

We can use the same implementation for the test function in this case, simply changing the dataset passed to the function.

Test Function (CIFAR-10):

def test(test_model, data, **kwargs):
    device = kwargs.get("device") if kwargs.get("device") else 'cpu'

    test_model = test_model.to(device)

    correct = 0
    total = 0
    overall_accuracy = 0

    with torch.no_grad():
        for (inputs, targets) in data:
            inputs, targets = inputs.to(device), targets.to(device)
            # calculate outputs by running images through the network
            outputs = test_model(inputs)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    overall_accuracy = 100 * correct / total
    print('Accuracy of the network on the 10000 test images: %d %%' % (overall_accuracy))

    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    with torch.no_grad():
        for (inputs, targets) in data:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = test_model(inputs)
            _, predictions = torch.max(outputs, 1)
            # collect the correct predictions for each class
            for target, prediction in zip(targets, predictions):
                if prediction == target:
                    correct_pred[classes[target]] += 1
                total_pred[classes[target]] += 1

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
                                                             accuracy))
    return overall_accuracy, 0
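
The functions above iterate over the dataset twice, once for overall accuracy and once for per-class accuracy. As a possible simplification, both can be gathered in a single pass. The helper below is a hypothetical sketch that works on plain integer labels; with PyTorch tensors, the labels and predictions could first be converted with `.tolist()`.

```python
from collections import Counter
from typing import Dict, Sequence, Tuple


def accuracy_report(targets: Sequence[int], predictions: Sequence[int],
                    classes: Sequence[str]) -> Tuple[float, Dict[str, float]]:
    """Return overall accuracy and per-class accuracy (both in percent) in one pass."""
    correct = Counter()
    total = Counter()
    for target, prediction in zip(targets, predictions):
        total[classes[target]] += 1
        if prediction == target:
            correct[classes[target]] += 1
    overall = 100.0 * sum(correct.values()) / sum(total.values())
    # Skip classes that never appear to avoid dividing by zero
    per_class = {name: 100.0 * correct[name] / total[name]
                 for name in classes if total[name]}
    return overall, per_class


overall, per_class = accuracy_report([0, 0, 1, 2], [0, 1, 1, 2], ("cat", "dog", "ship"))
print(overall)    # 75.0
print(per_class)  # {'cat': 50.0, 'dog': 100.0, 'ship': 100.0}
```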

Step 3: Construct Termination Function

The termination function is a user-defined function that controls when an agent exits an FL process. The function is run by the agent at the beginning of each round, and the agent exits if the function returns True.

One simple termination function is to return True after a certain number of rounds has passed; the following is an implementation of such a function:

def judge_termination(**kwargs) -> bool:
    """
    Decide if it finishes training process and exits from FL platform
    :param client: the STADLE client, used to read the current FL round
    :param round_to_exit: int - the round at which to stop
    :return: bool - True if it continues the training loop; False if it stops
    """

    keep_running = True
    client = kwargs.get('client')
    current_fl_round = client.federated_training_round

    if current_fl_round >= int(kwargs.get("round_to_exit")):
        keep_running = False
        client.stop_model_exchange_routine()
    return keep_running
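
To see what this function returns before and at the exit round, it can be exercised without a running aggregator. The `StubClient` class below is a hypothetical stand-in that provides only the attribute and method the function touches; it is not part of the stadle library.

```python
class StubClient:
    """Hypothetical stand-in exposing only what judge_termination uses."""

    def __init__(self, current_round: int):
        self.federated_training_round = current_round
        self.stopped = False

    def stop_model_exchange_routine(self) -> None:
        self.stopped = True


def judge_termination(**kwargs) -> bool:
    # Same logic as the example above, repeated so this sketch is self-contained
    keep_running = True
    client = kwargs.get('client')
    if client.federated_training_round >= int(kwargs.get("round_to_exit")):
        keep_running = False
        client.stop_model_exchange_routine()
    return keep_running


early = StubClient(current_round=19)
late = StubClient(current_round=20)
print(judge_termination(client=early, round_to_exit=20), early.stopped)  # True False
print(judge_termination(client=late, round_to_exit=20), late.stopped)    # False True
```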

Step 4: Setup, Connect IntegratedClient to STADLE

The following is example code to set up the IntegratedClient with the previously defined functions and start the FL process:

import argparse

parser = argparse.ArgumentParser(description='STADLE CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lt_epochs', default=3, type=int, help='local training epochs per round')

args = parser.parse_args()

device = 'cuda'

model = VGG('VGG16')

Read in learning rate and number of local training epochs from command line arguments, set training device and define model to be trained.

trainset = torchvision.datasets.CIFAR10(
        root='data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
        root='data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
        testset, batch_size=64, shuffle=False, num_workers=2)

Use the same CIFAR-10 datasets as the local training example

stadle_client.set_termination_function(judge_termination, round_to_exit=20, client=stadle_client)
stadle_client.set_training_function(train, trainloader, lr=args.lr, epochs=args.lt_epochs, device=device)
stadle_client.set_cross_validation_function(cross_validate, testloader, device=device)
stadle_client.set_testing_function(test, testloader)

Pass functions to IntegratedClient for use in internal training loop

stadle_client.set_bm_obj(model)
stadle_client.start()

Set the container model for the client, then start the agent FL process

Running Client-Side STADLE Components

After starting the requisite server-side STADLE components, there is one final step that must be run to fully initialize an FL process with STADLE and prepare for agent connections. The component responsible for this is called the admin agent. Its role in this case is to send the model structure and information to the persistence server for use in converting between specific model frameworks and the framework-agnostic model representation used by STADLE. The following is example admin agent code for the CIFAR-10 example:

from stadle import AdminAgent
from stadle import BaseModelConvFormat
from stadle.lib.entity.model import BaseModel
from stadle.lib.util import admin_arg_parser

from vgg import VGG

This section imports the required objects from STADLE, as well as a function for reading command line arguments and the VGG model. The BaseModel object acts as a container for information on the model being trained with STADLE, and is passed to the AdminAgent to be sent to the persistence server.

base_model = BaseModel("PyTorch-CIFAR10-Model", VGG('VGG16'), BaseModelConvFormat.pytorch_format)

The specific BaseModel object is then created with the VGG16 model structure and information.

args = admin_arg_parser()
admin_agent = AdminAgent(config_file=args.config_path, simulation_flag=args.simulation,
                         aggregator_ip_address=args.ip_address, reg_socket=args.reg_port,
                         exch_socket=args.exch_port, model_path=args.model_path, base_model=base_model,
                         agent_running=args.agent_running)

admin_agent.preload()
admin_agent.initialize()

The command line arguments are parsed and used to create the AdminAgent object, along with the base model. The preload function prepares the base model to be sent (converting to agnostic representation internally) and the initialize function sends the base model information, preparing all of the aggregators to connect to agents by extension.

After the admin agent is run, the main agent client-side code can freely be run. In summary, the order to run components is as follows:

  1. Start persistence server

  2. Start aggregator(s)

  3. Run admin agent (only once)

  4. Run agent(s) - client-side code