diff --git a/docs/infosec-orverview.md b/docs/infosec-orverview.md index 62f24982ec..6775305299 100644 --- a/docs/infosec-orverview.md +++ b/docs/infosec-orverview.md @@ -1,33 +1,41 @@ # InfoSec Overview -## Purpose +_Last updated: 19 May 2025_ + +### Purpose This document provides the information needed when evaluating OpenFL for real world deployment in highly sensitive environments. The target audience is InfoSec reviewers who need detailed information about code contents, communication traffic, and potential exploit vectors. -## Network Connectivity Overview -OpenFL federations use a hub-and-spoke topology between _collaborator_ clients that generate model parameter updates from their data and the _aggregator_ server that combines their training updates into new models [[ref](https://openfl.readthedocs.io/en/latest/about/features_index/taskrunner.html)]. Key details about this functionality are: -* Connections are made using request/response gRPC connections [[ref](https://grpc.io/docs/what-is-grpc/core-concepts/)]. -* The _aggregator_ listens for connections on a single port (usually decided by the experiment admin), and is explicitly defined in the FL plan (f.e. `50051`), so all _collaborators_ must be able to send outgoing traffic to this port. -* All connections are initiated by the _collaborator_, i.e., a `pull` architecture [[ref](https://karlchris.github.io/data-engineering/data-ingestion/push-pull/#pull)]. -* The _collaborator_ does not open any listening sockets. -* Connections are secured using mutually-authenticated TLS [[ref](https://www.cloudflare.com/learning/access-management/what-is-mutual-tls/)]. +### Overview: Network Connectivity +OpenFL federations use a hub-and-spoke topology between `collaborator` clients that generate model parameter updates from their data and the `aggregator` server that combines their training updates into new models[^1]. Key details about this functionality are: + +* Connections are made using request/response gRPC[^2] connections. +* The `aggregator` listens for connections on a single port, which is explicitly defined in the FL plan (e.g., `50051`) and usually decided by the experiment admin, so all `collaborator`s must be able to send outgoing traffic to this port. +* All connections are initiated by the `collaborator`, i.e., a [`pull`](https://karlchris.github.io/data-engineering/data-ingestion/push-pull/#pull) architecture. +* The `collaborator` does not open any listening sockets. +* Connections are secured using mTLS[^3]. * Each request response pair is done on a new TLS connection. -* The PKI for federations can be created using the [OpenFL CLI](https://openfl.readthedocs.io/en/latest/about/features_index/taskrunner.html#step-2-configure-the-federation). OpenFL internally leverages Python's cryptography module. The organization hosting the _aggregator_ usually acts as the Certificate Authority (CA) and verifies each identity before signing. -* Currently, the _collaborator_ polls the _aggregator_ at a fixed interval. We have had a request to enable client-side configuration of this interval and hope to support that feature soon. +* The PKI for federations is created using the [`fx aggregator/collaborator certify`](https://openfl.readthedocs.io/en/latest/fx.html) CLI command. OpenFL internally leverages Python's cryptography module. The organization hosting the `aggregator` usually acts as the Certificate Authority (CA) and verifies each identity before signing. +* Currently, the `collaborator` polls the `aggregator` at a fixed interval.
We have had a request to enable client-side configuration of this interval and hope to support that feature soon. * Connection timeouts are set to gRPC defaults. -* If the _aggregator_ is not available, the _collaborator_ will retry connections indefinitely. This is currently useful so that we can take the aggregator down for bugfixes without _collaborator_ processes exiting. +* If the `aggregator` is not available, the `collaborator` will retry connections indefinitely. This is currently useful so that we can take the aggregator down for bugfixes without `collaborator` processes exiting. -## Overview of Contents of Network Messages +### Contents of Network Messages Network messages are well defined protobufs which can be found in the following files: -- [aggregator.proto](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto) -- [base.proto](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/base.proto) +- [`aggregator.proto`](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto) +- [`base.proto`](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/base.proto) Key points about the network messages/protocol: * No executable code is ever sent to the collaborator. All code to be executed is contained within the OpenFL package and the custom FL workspace. The code, along with the FL plan file that specifies the classes and initial parameters to be used, is available for review prior to the FL plans execution. This ensures that all potential operations are understood before they take place. -* The _collaborator_ typically requests the FL tasks to execute from the aggregator via the `GetTasksRequest` message [[ref](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto#L34)] -* The _aggregator_ reads the FL plan and returns a `GetTasksResponse` [[ref](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto#L45)] which includes metadata (`Tasks`) [[ref](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto#L38)] about the Python functions to be invoked by the collaborator (the code being installed locally as part of a pre-distributed workspace bundle) -* The _collaborator_ then uses its TaskRunner framework to execute the FL tasks on the locally available data, producing output tensors such as model weights or metrics -* During task execution, the _collaborator_ may additionally request tensors from the aggregator via the `GetAggregatedTensor` RPC method [[ref](https://openfl.readthedocs.io/en/latest/reference/_autosummary/openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.html#openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.GetAggregatedTensor)] -* Upon task completion, the _collaborator_ transmits the results by emitting a `SendLocalTaskResults` call [[ref](https://openfl.readthedocs.io/en/latest/reference/_autosummary/openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.html#openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.SendLocalTaskResults)] which contains `NamedTensor` [[ref](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/base.proto#L11)] objects that encode model weight updates or ML metrics such as loss or accuracy (among others). 
- -## Testing a Collaborator -There is a "no-op" workspace template in OpenFL (available in versions `>=1.9`) which can be used to test the network connection between the _aggregator_ and each _collaborator_ without performing any computational task. More details can be found [here](https://github.com/securefederatedai/openfl/tree/develop/openfl-workspace/no-op#overview). +* The `collaborator` typically requests the FL tasks to execute from the aggregator via a [`GetTasksRequest`](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto#L34) message. +* The `aggregator` based on the FL plan, returns a [`GetTasksResponse`](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto#L45) which includes [`Tasks`](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/aggregator.proto#L38) - metadata about the Python functions to be invoked by the collaborator. All code is available locally to each collaborator as part of a pre-distributed workspace bundle. +* The `collaborator` then uses its TaskRunner framework to execute the FL tasks on the locally available data, producing output tensors such as model weights or metrics. +* During task execution, the `collaborator` may require certain tensors for task execution that are not available locally. For example, a federated training task requires globally averaged model weights from the `aggregator`. Collaborators gather a list of tensor keys that need to be fetched from the aggregator and download them via the [`GetAggregatedTensors`](https://openfl.readthedocs.io/en/latest/reference/_autosummary/openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.html#openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.GetAggregatedTensor) RPC method. +* Upon task completion, the `collaborator` transmits the results by emitting a [`SendLocalTaskResults`](https://openfl.readthedocs.io/en/latest/reference/_autosummary/openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.html#openfl.transport.grpc.aggregator_server.AggregatorGRPCServer.SendLocalTaskResults) RPC method which contains [`NamedTensor`](https://github.com/securefederatedai/openfl/blob/develop/openfl/protocols/base.proto#L11) objects that encode results (like model weight updates or metrics such as loss or accuracy). + +### Testing a Collaborator +There is a "no-op" workspace template in OpenFL (available in versions `>=1.9`) which can be used to test the network connection between the `aggregator` and each `collaborator` without performing any computational task. More details can be found [here](https://github.com/securefederatedai/openfl/tree/develop/openfl-workspace/no-op#overview). 
+ + +[^1]: [OpenFL TaskRunner Overview](https://openfl.readthedocs.io/en/latest/about/features_index/taskrunner.html) +[^2]: [gRPC Overview](https://grpc.io/docs/what-is-grpc/core-concepts/) +[^3]: [mTLS Overview](https://www.cloudflare.com/learning/access-management/what-is-mutual-tls/) \ No newline at end of file diff --git a/openfl-workspace/gandlf_seg_test/src/dataloader.py b/openfl-workspace/gandlf_seg_test/src/dataloader.py index ce53a006ad..f95eaf14d1 100644 --- a/openfl-workspace/gandlf_seg_test/src/dataloader.py +++ b/openfl-workspace/gandlf_seg_test/src/dataloader.py @@ -39,4 +39,4 @@ def get_feature_shape(self): """ # Define a fixed feature shape for this specific application # Use standard 3D patch size for medical imaging segmentation - return self.feature_shape \ No newline at end of file + return self.feature_shape diff --git a/openfl/callbacks/secure_aggregation.py b/openfl/callbacks/secure_aggregation.py index 4fade8d79a..7fc66171b9 100644 --- a/openfl/callbacks/secure_aggregation.py +++ b/openfl/callbacks/secure_aggregation.py @@ -295,5 +295,6 @@ def _fetch_from_aggregator(self, key_name): Returns: bytes: The aggregated tensor data in bytes. """ - tensor = self.client.get_aggregated_tensor(key_name, -1, False, ("secagg",), True) + key = TensorKey(key_name, self.name, -1, False, ("secagg",)) + tensor = self.client.get_aggregated_tensors([key], require_lossless=True)[0] return json.loads(tensor.data_bytes) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index f1c07cf0f4..f0e260fe6a 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -6,7 +6,6 @@ import json import logging import queue -import time from threading import Lock from typing import List, Optional @@ -17,7 +16,7 @@ from openfl.databases import PersistentTensorDB, TensorDB from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec -from openfl.protocols import base_pb2, utils +from openfl.protocols import utils from openfl.protocols.base_pb2 import NamedTensor from openfl.utilities import TaskResultKey, TensorKey, change_tags @@ -84,7 +83,6 @@ def __init__( single_col_cert_common_name=None, compression_pipeline=None, db_store_rounds=1, - initial_tensor_dict=None, log_memory_usage=False, write_logs=False, callbacks: Optional[List] = [], @@ -92,31 +90,6 @@ def __init__( persistent_db_path=None, secure_aggregation=False, ): - """Initializes the Aggregator. - - Args: - aggregator_uuid (int): Aggregation ID. - federation_uuid (str): Federation ID. - authorized_cols (list of str): The list of IDs of enrolled - collaborators. - init_state_path (str): The location of the initial weight file. - best_state_path (str): The file location to store the weight of - the best model. - last_state_path (str): The file location to store the latest - weight. - assigner: Assigner object. - straggler_handling_policy (optional): Straggler handling policy. - rounds_to_train (int, optional): Number of rounds to train. - Defaults to 256. - single_col_cert_common_name (str, optional): Common name for single - collaborator certificate. Defaults to None. - compression_pipeline (optional): Compression pipeline. Defaults to - NoCompressionPipeline. - db_store_rounds (int, optional): Rounds to store in TensorDB. - Defaults to 1. - initial_tensor_dict (dict, optional): Initial tensor dictionary. 
- callbacks: List of callbacks to be used during the experiment. - """ self.round_number = 0 self.next_model_round_number = 0 @@ -206,20 +179,10 @@ def __init__( last_state_path=self.last_state_path, ) - if initial_tensor_dict: - self._load_initial_tensors_from_dict(initial_tensor_dict) - self.model = utils.construct_model_proto( - tensor_dict=initial_tensor_dict, - round_number=0, - tensor_pipe=self.compression_pipeline, - ) - else: - if self.connector: - # The model definition will be handled by the respective framework - self.model = {} - else: - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) - self._load_initial_tensors() # keys are TensorKeys + self.model = {} + if not self.connector: + self.model = utils.load_proto(self.init_state_path) + self._load_initial_tensors() # keys are TensorKeys self._secure_aggregation_enabled = secure_aggregation if self._secure_aggregation_enabled: @@ -338,23 +301,6 @@ def _load_initial_tensors(self): self.tensor_db.cache_tensor(tensor_key_dict) logger.debug("This is the initial tensor_db: %s", self.tensor_db) - def _load_initial_tensors_from_dict(self, tensor_dict): - """Load all of the tensors required to begin federated learning. - - Required tensors are: \ - 1. Initial model. - - Returns: - None - """ - tensor_key_dict = { - TensorKey(k, self.uuid, self.round_number, False, ("model",)): v - for k, v in tensor_dict.items() - } - # all initial model tensors are loaded here - self.tensor_db.cache_tensor(tensor_key_dict) - logger.debug("This is the initial tensor_db: %s", self.tensor_db) - def _save_model(self, round_number, file_path): """Save the best or latest model. @@ -486,10 +432,7 @@ def get_tasks(self, collaborator_name): # first, if it is time to quit, inform the collaborator if self._time_to_quit(): - logger.info( - "Sending signal to collaborator %s to shutdown...", - collaborator_name, - ) + logger.info("Sending signal to collaborator %s to shutdown...", collaborator_name) self.quit_job_sent_to.append(collaborator_name) tasks = None @@ -611,74 +554,28 @@ def get_aggregated_tensor( Raises: ValueError: if Aggregator does not have an aggregated tensor for {tensor_key}. """ - if "compressed" in tags or require_lossless: - compress_lossless = True - else: - compress_lossless = False - if not self._check_tags(tags, requested_by): logger.error( - "Tag check failed: unauthorized tags detected. Only '%s' is allowed.", requested_by + "Collaborator `%s` is not allowed to fetch tensor with tags `%s`.", + requested_by, + tags, ) return NamedTensor() - # TODO the TensorDB doesn't support compressed data yet. - # The returned tensor will be recompressed anyway. + # We simply remove compression-related tags because serializer adds them. 
if "compressed" in tags: tags = change_tags(tags, remove_field="compressed") if "lossy_compressed" in tags: tags = change_tags(tags, remove_field="lossy_compressed") + # Fetch tensor tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, tags) - tensor_name, origin, round_number, report, tags = tensor_key - - if "aggregated" in tags and "delta" in tags and round_number != 0: - agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("aggregated",)) - else: - agg_tensor_key = tensor_key - - nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key) - - start_retrieving_time = time.time() - while nparray is None: - logger.debug("Waiting for tensor_key %s", agg_tensor_key) - time.sleep(5) - nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key) - if (time.time() - start_retrieving_time) > 60: - break - + nparray = self.tensor_db.get_tensor_from_cache(tensor_key) if nparray is None: - raise ValueError(f"Aggregator does not have an aggregated tensor for {tensor_key}") - - # quite a bit happens in here, including compression, delta handling, - # etc... - # we might want to cache these as well - named_tensor = self._nparray_to_named_tensor( - agg_tensor_key, nparray, send_model_deltas=True, compress_lossless=compress_lossless - ) - - return named_tensor - - def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, compress_lossless): - """Construct the NamedTensor Protobuf. - - Also includes logic to create delta, compress tensors with the - TensorCodec, etc. + raise ValueError(f"Aggregator does not have `{tensor_key}`") - Args: - tensor_key (TensorKey): Tensor key. - nparray (np.array): Numpy array. - send_model_deltas (bool): Whether to send model deltas. - compress_lossless (bool): Whether to compress lossless. - - Returns: - tensor_key (TensorKey): Tensor key. - nparray (np.array): Numpy array. - - """ - tensor_name, origin, round_number, report, tags = tensor_key - # Secure aggregation setup tensor. - if "secagg" in tags: + # Serialize (and compress) the tensor + if "secagg" in tensor_key.tags: import numpy as np class NumpyEncoder(json.JSONEncoder): @@ -695,43 +592,10 @@ def default(self, obj): ) return named_tensor - # if we have an aggregated tensor, we can make a delta - if "aggregated" in tags and send_model_deltas: - # Should get the pretrained model to create the delta. 
If training - # has happened, Model should already be stored in the TensorDB - model_tk = TensorKey(tensor_name, origin, round_number - 1, report, ("model",)) - - model_nparray = self.tensor_db.get_tensor_from_cache(model_tk) - - assert model_nparray is not None, ( - "The original model layer should be present if the latest " - "aggregated model is present" - ) - delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta( - tensor_key, nparray, model_nparray - ) - delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress( - delta_tensor_key, delta_nparray, lossless=compress_lossless - ) - named_tensor = utils.construct_named_tensor( - delta_comp_tensor_key, - delta_comp_nparray, - metadata, - lossless=compress_lossless, - ) - - else: - # Assume every other tensor requires lossless compression - compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress( - tensor_key, nparray, require_lossless=True - ) - named_tensor = utils.construct_named_tensor( - compressed_tensor_key, - compressed_nparray, - metadata, - lossless=compress_lossless, - ) + named_tensor = utils.serialize_tensor( + tensor_key, nparray, self.tensor_codec, lossless=require_lossless + ) return named_tensor def _collaborator_task_completed(self, collaborator, task_name, round_num): @@ -830,45 +694,47 @@ def process_task_results( self._is_collaborator_done(collaborator_name, round_number) self._end_of_round_with_stragglers_check() - task_key = TaskResultKey(task_name, collaborator_name, round_number) - - # we mustn't have results already if self._collaborator_task_completed(collaborator_name, task_name, round_number): logger.warning( - f"Aggregator already has task results from collaborator {collaborator_name}" - f" for task {task_key}" + f"Aggregator already has task results from collaborator {collaborator_name} " + f"for task {task_name} in round {round_number}. Ignoring..." ) return - # By giving task_key it's own weight, we can support different - # training/validation weights - # As well as eventually supporting weights that change by round - # (if more data is added) + # Record collaborator individual weightage/contribution for federated averaging + task_key = TaskResultKey(task_name, collaborator_name, round_number) self.collaborator_task_weight[task_key] = data_size - # initialize the list of tensors that go with this task - # Setting these incrementally is leading to missing values + # Process named tensors task_results = [] - + result_tensor_dict = {} for named_tensor in named_tensors: - # quite a bit happens in here, including decompression, delta - # handling, etc... - tensor_key, value = self._process_named_tensor(named_tensor, collaborator_name) + # Deserialize + tensor_key, nparray = utils.deserialize_tensor(named_tensor, self.tensor_codec) + + # Update origin/tags + updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name) + tensor_key = tensor_key._replace(origin=self.uuid, tags=updated_tags) + + # Record + result_tensor_dict[tensor_key] = nparray + task_results.append(tensor_key) if "metric" in tensor_key.tags: - # Caution: This schema must be followed. It is also used in - # gRPC message streams for director/envoy. 
+ assert nparray.ndim == 0, ( + f"Expected metric to be a scalar, got shape {nparray.shape}" + ) metrics = { "round": round_number, "metric_origin": collaborator_name, "task_name": task_name, "metric_name": tensor_key.tensor_name, - "metric_value": float(value), + "metric_value": float(nparray), } self.metric_queue.put(metrics) - task_results.append(tensor_key) - + # Store results in TensorDB + self.tensor_db.cache_tensor(result_tensor_dict) self.collaborator_tasks_results[task_key] = task_results # Check if collaborator or round is done. @@ -895,97 +761,6 @@ def _end_of_round_with_stragglers_check(self): logger.warning(f"Identified stragglers: {self.stragglers}") self._end_of_round_check() - def _process_named_tensor(self, named_tensor, collaborator_name): - """Extract the named tensor fields. - - Performs decompression, delta computation, and inserts results into - TensorDB. - - Args: - named_tensor (protobuf NamedTensor): Named tensor. - protobuf that will be extracted from and processed - collaborator_name (str): Collaborator name. - Collaborator name is needed for proper tagging of resulting - tensorkeys. - - Returns: - tensor_key (TensorKey): Tensor key. - The tensorkey extracted from the protobuf. - nparray (np.array): Numpy array. - The numpy array associated with the returned tensorkey. - """ - raw_bytes = named_tensor.data_bytes - metadata = [ - { - "int_to_float": proto.int_to_float, - "int_list": proto.int_list, - "bool_list": proto.bool_list, - } - for proto in named_tensor.transformer_metadata - ] - # The tensor has already been transferred to aggregator, - # so the newly constructed tensor should have the aggregator origin - tensor_key = TensorKey( - named_tensor.name, - self.uuid, - named_tensor.round_number, - named_tensor.report, - tuple(named_tensor.tags), - ) - tensor_name, origin, round_number, report, tags = tensor_key - - assert "compressed" in tags or "lossy_compressed" in tags, ( - f"Named tensor {tensor_key} is not compressed" - ) - if "compressed" in tags: - dec_tk, decompressed_nparray = self.tensor_codec.decompress( - tensor_key, - data=raw_bytes, - transformer_metadata=metadata, - require_lossless=True, - ) - dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk - # Need to add the collaborator tag to the resulting tensor - new_tags = change_tags(dec_tags, add_field=collaborator_name) - - # layer.agg.n.trained.delta.col_i - decompressed_tensor_key = TensorKey( - dec_name, dec_origin, dec_round_num, dec_report, new_tags - ) - if "lossy_compressed" in tags: - dec_tk, decompressed_nparray = self.tensor_codec.decompress( - tensor_key, - data=raw_bytes, - transformer_metadata=metadata, - require_lossless=False, - ) - dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk - new_tags = change_tags(dec_tags, add_field=collaborator_name) - # layer.agg.n.trained.delta.lossy_decompressed.col_i - decompressed_tensor_key = TensorKey( - dec_name, dec_origin, dec_round_num, dec_report, new_tags - ) - - if "delta" in tags: - base_model_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("model",)) - base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tensor_key) - if base_model_nparray is None: - raise ValueError(f"Base model {base_model_tensor_key} not present in TensorDB") - final_tensor_key, final_nparray = self.tensor_codec.apply_delta( - decompressed_tensor_key, - decompressed_nparray, - base_model_nparray, - ) - else: - final_tensor_key = decompressed_tensor_key - final_nparray = decompressed_nparray - - assert 
final_nparray is not None, f"Could not create tensorkey {final_tensor_key}" - self.tensor_db.cache_tensor({final_tensor_key: final_nparray}) - logger.debug("Created TensorKey: %s", final_tensor_key) - - return final_tensor_key, final_nparray - def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results): """Prepare aggregated tensorkey tags. @@ -996,82 +771,13 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result report (bool): Whether to report. agg_results (np.array): Aggregated results. """ - # The aggregated tensorkey tags should have the form of - # 'trained' or 'trained.lossy_decompressed' - # They need to be relabeled to 'aggregated' and - # reinserted. Then delta performed, compressed, etc. - # then reinserted to TensorDB with 'model' tag - - # First insert the aggregated model layer with the - # correct tensorkey agg_tag_tk = TensorKey(tensor_name, origin, round_number + 1, report, ("aggregated",)) self.tensor_db.cache_tensor({agg_tag_tk: agg_results}) - # Create delta and save it in TensorDB - base_model_tk = TensorKey(tensor_name, origin, round_number, report, ("model",)) - base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk) - if base_model_nparray is not None and self.use_delta_updates: - delta_tk, delta_nparray = self.tensor_codec.generate_delta( - agg_tag_tk, agg_results, base_model_nparray - ) - else: - # This condition is possible for base model - # optimizer states (i.e. Adam/iter:0, SGD, etc.) - # These values couldn't be present for the base - # model because no training occurs on the aggregator - delta_tk, delta_nparray = agg_tag_tk, agg_results - - # Compress lossless/lossy - compressed_delta_tk, compressed_delta_nparray, metadata = self.tensor_codec.compress( - delta_tk, delta_nparray - ) - - # TODO extend the TensorDB so that compressed data is - # supported. 
Once that is in place - # the compressed delta can just be stored here instead - # of recreating it for every request - - # Decompress lossless/lossy - decompressed_delta_tk, decompressed_delta_nparray = self.tensor_codec.decompress( - compressed_delta_tk, compressed_delta_nparray, metadata - ) - - self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray}) - - # Apply delta (unless delta couldn't be created) - if base_model_nparray is not None and self.use_delta_updates: - logger.debug("Applying delta for layer %s", decompressed_delta_tk[0]) - new_model_tk, new_model_nparray = self.tensor_codec.apply_delta( - decompressed_delta_tk, - decompressed_delta_nparray, - base_model_nparray, - ) - else: - new_model_tk, new_model_nparray = ( - decompressed_delta_tk, - decompressed_delta_nparray, - ) - - # Now that the model has been compressed/decompressed - # with delta operations, - # Relabel the tags to 'model' - ( - new_model_tensor_name, - new_model_origin, - new_model_round_number, - new_model_report, - new_model_tags, - ) = new_model_tk - final_model_tk = TensorKey( - new_model_tensor_name, - new_model_origin, - new_model_round_number, - new_model_report, - ("model",), - ) - self.next_model_round_number = new_model_round_number - # Finally, cache the updated model tensor - self.tensor_db.cache_tensor({final_model_tk: new_model_nparray}) + # Relabel the tags to 'model' and cache the updated model tensor + final_model_tk = agg_tag_tk._replace(tags=("model",)) + self.next_model_round_number = final_model_tk.round_number + self.tensor_db.cache_tensor({final_model_tk: agg_results}) def _compute_validation_related_task_metrics(self, task_name) -> dict: """Compute all validation related metrics. @@ -1091,7 +797,6 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict: ) # Leave out straggler for the round even if they've partially # completed given tasks - collaborators_for_task = [] collaborators_for_task = [ c for c in all_collaborators_for_task if c in self.collaborators_done ] @@ -1160,14 +865,10 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict: if not self.assigner.is_task_group_evaluation(): logger.info( f"Round {round_number}: saved the best model with score " - "{agg_results:f}" + f"{agg_results:f}" ) self._save_model(round_number, self.best_state_path) - else: - logger.info( - f"Round {round_number}: best score observed {agg_results:f} " - "(model not saved in evaluation mode)" - ) + if "trained" in tags: self._prepare_trained(tensor_name, origin, round_number, report, agg_results) diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 7f66805753..bd3b22df34 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -190,6 +190,7 @@ def run(self): # Run tasks logs = {} for task in tasks: + logger.info("Task: `%s`", task.name) metrics = self.do_task(task, round_num) logs.update(metrics) @@ -205,58 +206,53 @@ def do_task(self, task, round_number) -> dict: """Perform the specified task. Args: - task (list_of_str): List of tasks. - round_number (int): Actual round number. + task: Task proto. + round_number (int): Round number. Returns: A dictionary of reportable metrics of the current collaborator for the task. 
""" - # map this task to an actual function name and kwargs - if isinstance(task, str): - task_name = task - else: - task_name = task.name - func_name = self.task_config[task_name]["function"] - kwargs = self.task_config[task_name]["kwargs"] + func_name = self.task_config[task.name]["function"] + kwargs = self.task_config[task.name]["kwargs"] # this would return a list of what tensors we require as TensorKeys - required_tensorkeys_relative = self.task_runner.get_required_tensorkeys_for_function( - func_name, **kwargs - ) - # models actually return "relative" tensorkeys of (name, LOCAL|GLOBAL, - # round_offset) - # so we need to update these keys to their "absolute values" - required_tensorkeys = [] - for ( - tname, - origin, - rnd_num, - report, - tags, - ) in required_tensorkeys_relative: - if origin == "GLOBAL": - origin = self.aggregator_uuid - else: - origin = self.collaborator_name + # round_offset) so we need to update these keys to their "absolute values" + tensor_keys = self.task_runner.get_required_tensorkeys_for_function(func_name, **kwargs) + global_keys, local_keys = [], [] + for tensor_key in tensor_keys: + if tensor_key.origin == "GLOBAL": + tensor_key = tensor_key._replace( + origin=self.aggregator_uuid, round_number=round_number + ) + global_keys.append(tensor_key) - # rnd_num is the relative round. So if rnd_num is -1, get the - # tensor from the previous round - required_tensorkeys.append( - TensorKey(tname, origin, rnd_num + round_number, report, tags) - ) + elif tensor_key.origin == "LOCAL": + tensor_key = tensor_key._replace( + origin=self.collaborator_name, round_number=round_number + ) + local_keys.append(tensor_key) + + # Prepare input tensor dict for this task + self.fetch_tensors_from_aggregator(global_keys) + input_tensor_dict = {} + for tk in local_keys: + value = self.tensor_db.get_tensor_from_cache(tk) + if value is None: + raise ValueError(f"Value corresponding to local tensor `{tk}` not found.") + input_tensor_dict[tk.tensor_name] = value + + for tk in global_keys: + value = self.tensor_db.get_tensor_from_cache(tk) + if value is None: + raise ValueError(f"Value corresponding to global tensor `{tk}` not found.") + input_tensor_dict[tk.tensor_name] = value + + self.callbacks.on_task_begin(task.name, round_number) - # print('Required tensorkeys = {}'.format( - # [tk[0] for tk in required_tensorkeys])) - input_tensor_dict = { - k.tensor_name: self.get_data_for_tensorkey(k) for k in required_tensorkeys - } - self.callbacks.on_task_begin(task_name, round_number) # now we have whatever the model needs to do the task # Tasks are defined as methods of TaskRunner func = getattr(self.task_runner, func_name) - logger.debug("Using TaskRunner subclassing API") - global_output_tensor_dict, local_output_tensor_dict = func( col_name=self.collaborator_name, round_num=round_number, @@ -264,7 +260,7 @@ def do_task(self, task, round_number) -> dict: **kwargs, ) - self.callbacks.on_task_end(task_name, round_number) + self.callbacks.on_task_end(task.name, round_number) # If secure aggregation is enabled, add masks to the dict to be shared # with the aggregator. 
@@ -277,142 +273,35 @@ def do_task(self, task, round_number) -> dict: # send the results for this tasks; delta and compression will occur in # this function - metrics = self.send_task_results(global_output_tensor_dict, round_number, task_name) + metrics = self.send_task_results(global_output_tensor_dict, round_number, task.name) return metrics - def get_data_for_tensorkey(self, tensor_key): - """Resolve the tensor corresponding to the requested tensorkey. + def fetch_tensors_from_aggregator(self, tensor_keys: List[TensorKey]): + """Fetches tensors from the aggregator and stores them locally. - Args: - tensor_key (namedtuple): Tensorkey that will be resolved locally or - remotely. May be the product of other tensors. - - Returns: - nparray: The decompressed tensor associated with the requested - tensor key. - """ - # try to get from the store - tensor_name, origin, round_number, report, tags = tensor_key - logger.debug("Attempting to retrieve tensor %s from local store", tensor_key) - nparray = self.tensor_db.get_tensor_from_cache(tensor_key) - - # if None and origin is our client, request it from the client - if nparray is None: - if origin == self.collaborator_name: - logger.info( - f"Attempting to find locally stored {tensor_name} tensor from prior round..." - ) - prior_round = round_number - 1 - while prior_round >= 0: - nparray = self.tensor_db.get_tensor_from_cache( - TensorKey(tensor_name, origin, prior_round, report, tags) - ) - if nparray is not None: - logger.debug( - f"Found tensor {tensor_name} in local TensorDB for round {prior_round}" - ) - return nparray - prior_round -= 1 - logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...") - # Determine whether there are additional compression related - # dependencies. - # Typically, dependencies are only relevant to model layers - tensor_dependencies = self.tensor_codec.find_dependencies( - tensor_key, self.use_delta_updates - ) - logger.debug( - "Unable to get tensor from local store..." - "attempting to retrieve from client len tensor_dependencies" - f" tensor_key {tensor_key}" - ) - if len(tensor_dependencies) > 0: - # Resolve dependencies - # tensor_dependencies[0] corresponds to the prior version - # of the model. - # If it exists locally, should pull the remote delta because - # this is the least costly path - prior_model_layer = self.tensor_db.get_tensor_from_cache(tensor_dependencies[0]) - if prior_model_layer is not None: - uncompressed_delta = self.get_aggregated_tensor_from_aggregator( - tensor_dependencies[1] - ) - new_model_tk, nparray = self.tensor_codec.apply_delta( - tensor_dependencies[1], - uncompressed_delta, - prior_model_layer, - creates_model=True, - ) - self.tensor_db.cache_tensor({new_model_tk: nparray}) - else: - logger.info( - "Could not find previous model layer.Fetching latest layer from aggregator" - ) - # The original model tensor should be fetched from aggregator - nparray = self.get_aggregated_tensor_from_aggregator( - tensor_key, require_lossless=True - ) - elif "model" in tags: - # Pulling the model for the first time - nparray = self.get_aggregated_tensor_from_aggregator( - tensor_key, require_lossless=True - ) - else: - # we should try fetching the tensor from aggregator - tensor_name, origin, round_number, report, tags = tensor_key - tags = (self.collaborator_name,) + tags - tensor_key = (tensor_name, origin, round_number, report, tags) - logger.info( - "Could not find previous model layer." 
- f"Fetching latest layer from aggregator {tensor_key}" - ) - nparray = self.get_aggregated_tensor_from_aggregator( - tensor_key, require_lossless=True - ) - else: - logger.debug("Found tensor %s in local TensorDB", tensor_key) - - return nparray - - def get_aggregated_tensor_from_aggregator(self, tensor_key, require_lossless=False): - """ - Return the decompressed tensor associated with the requested tensor key. - - If the key requests a compressed tensor (in the tag), the tensor will - be decompressed before returning. - If the key specifies an uncompressed tensor (or just omits a compressed - tag), the decompression operation will be skipped. + This function checks if the tensors are already cached in the local database + and fetches them from the aggregator if not. The fetched tensors are then + cached in the local database. Args: - tensor_key (namedtuple): The requested tensor. - require_lossless (bool): Should compression of the tensor be - allowed in flight? For the initial model, it may affect - convergence to apply lossy compression. And metrics shouldn't - be compressed either. - - Returns: - nparray : The decompressed tensor associated with the requested - tensor key. + tensor_keys (list): List of TensorKeys to fetch. """ - tensor_name, origin, round_number, report, tags = tensor_key - - logger.debug("Requesting aggregated tensor %s", tensor_key) - tensor = self.client.get_aggregated_tensor( - tensor_name, - round_number, - report, - tags, - require_lossless, + tensor_dict = {} + tensor_keys = list( + filter(lambda k: self.tensor_db.get_tensor_from_cache(k) is None, tensor_keys) ) + if len(tensor_keys) > 0: + logger.info("Fetching %d tensors from the aggregator", len(tensor_keys)) + named_tensors = self.client.get_aggregated_tensors(tensor_keys, require_lossless=True) - # this translates to a numpy array and includes decompression, as - # necessary - nparray = self.named_tensor_to_nparray(tensor) + # Deserialize tensors and mark them as coming from the aggregator. + for tensor_key, named_tensor in zip(tensor_keys, named_tensors): + tensor_key, nparray = utils.deserialize_tensor(named_tensor, self.tensor_codec) + tensor_key = tensor_key._replace(origin=self.aggregator_uuid) + tensor_dict[tensor_key] = nparray - # cache this tensor - self.tensor_db.cache_tensor({tensor_key: nparray}) - - return nparray + self.tensor_db.cache_tensor(tensor_dict) def send_task_results(self, tensor_dict, round_number, task_name) -> dict: """Send task results to the aggregator. @@ -425,8 +314,6 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict: Returns: A dictionary of reportable metrics of the current collaborator for the task. """ - named_tensors = [self.nparray_to_named_tensor(k, v) for k, v in tensor_dict.items()] - # for general tasks, there may be no notion of data size to send. # But that raises the question how to properly aggregate results. 
@@ -449,6 +336,12 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict: value = float(tensor_dict[tensor]) metrics.update({f"{self.collaborator_name}/{task_name}/{tensor_name}": value}) + # Serialize tensors to be sent to the aggregator + named_tensors = [ + utils.serialize_tensor(k, v, self.tensor_codec, lossless=True) + for k, v in tensor_dict.items() + ] + self.client.send_local_task_results( round_number, task_name, @@ -458,111 +351,6 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict: return metrics - def nparray_to_named_tensor(self, tensor_key, nparray): - """Construct the NamedTensor Protobuf. - - Includes logic to create delta, compress tensors with the TensorCodec, - etc. - - Args: - tensor_key (namedtuple): Tensorkey that will be resolved locally or - remotely. May be the product of other tensors. - nparray: The decompressed tensor associated with the requested - tensor key. - - Returns: - named_tensor (protobuf) : The tensor constructed from the nparray. - """ - # if we have an aggregated tensor, we can make a delta - tensor_name, origin, round_number, report, tags = tensor_key - if "trained" in tags and self.use_delta_updates: - # Should get the pretrained model to create the delta. If training - # has happened, - # Model should already be stored in the TensorDB - model_nparray = self.tensor_db.get_tensor_from_cache( - TensorKey(tensor_name, origin, round_number, report, ("model",)) - ) - - # The original model will not be present for the optimizer on the - # first round. - if model_nparray is not None: - delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta( - tensor_key, nparray, model_nparray - ) - delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress( - delta_tensor_key, delta_nparray - ) - - named_tensor = utils.construct_named_tensor( - delta_comp_tensor_key, - delta_comp_nparray, - metadata, - lossless=False, - ) - return named_tensor - - # Assume every other tensor requires lossless compression - compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress( - tensor_key, nparray, require_lossless=True - ) - named_tensor = utils.construct_named_tensor( - compressed_tensor_key, compressed_nparray, metadata, lossless=True - ) - - return named_tensor - - def named_tensor_to_nparray(self, named_tensor): - """Convert named tensor to a numpy array. - - Args: - named_tensor (protobuf): The tensor to convert to nparray. - - Returns: - decompressed_nparray (nparray): The nparray converted. 
- """ - # do the stuff we do now for decompression and frombuffer and stuff - # This should probably be moved back to protoutils - raw_bytes = named_tensor.data_bytes - metadata = [ - { - "int_to_float": proto.int_to_float, - "int_list": proto.int_list, - "bool_list": proto.bool_list, - } - for proto in named_tensor.transformer_metadata - ] - # The tensor has already been transferred to collaborator, so - # the newly constructed tensor should have the collaborator origin - tensor_key = TensorKey( - named_tensor.name, - self.collaborator_name, - named_tensor.round_number, - named_tensor.report, - tuple(named_tensor.tags), - ) - *_, tags = tensor_key - if "compressed" in tags: - decompressed_tensor_key, decompressed_nparray = self.tensor_codec.decompress( - tensor_key, - data=raw_bytes, - transformer_metadata=metadata, - require_lossless=True, - ) - elif "lossy_compressed" in tags: - decompressed_tensor_key, decompressed_nparray = self.tensor_codec.decompress( - tensor_key, data=raw_bytes, transformer_metadata=metadata - ) - else: - # There could be a case where the compression pipeline is bypassed - # entirely - logger.warning("Bypassing tensor codec...") - decompressed_tensor_key = tensor_key - decompressed_nparray = raw_bytes - - self.tensor_db.cache_tensor({decompressed_tensor_key: decompressed_nparray}) - - return decompressed_nparray - def _apply_masks( self, tensor_dict, diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index a324934e76..9e0e0a96d3 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -364,16 +364,11 @@ def get_tasks(self): tasks[task]["aggregation_type"] = aggregation_type return tasks - def get_aggregator(self, tensor_dict=None): + def get_aggregator(self): """Get federation aggregator. This method retrieves the federation aggregator. If the aggregator - does not exist, it is built using the configuration settings and the - provided tensor dictionary. - - Args: - tensor_dict (dict, optional): The initial tensor dictionary to use - when building the aggregator. Defaults to None. + does not exist, it is built using the configuration settings. Returns: self.aggregator_ (Aggregator): The federation aggregator. @@ -401,7 +396,7 @@ def get_aggregator(self, tensor_dict=None): # TODO: Load callbacks from plan. 
if self.aggregator_ is None: - self.aggregator_ = Plan.build(**defaults, initial_tensor_dict=tensor_dict) + self.aggregator_ = Plan.build(**defaults) return self.aggregator_ diff --git a/openfl/protocols/aggregator.proto b/openfl/protocols/aggregator.proto index 08cc381765..05682ff580 100644 --- a/openfl/protocols/aggregator.proto +++ b/openfl/protocols/aggregator.proto @@ -11,7 +11,7 @@ import "openfl/protocols/base.proto"; service Aggregator { rpc Ping(PingRequest) returns (PingResponse) {} rpc GetTasks(GetTasksRequest) returns (GetTasksResponse) {} - rpc GetAggregatedTensor(GetAggregatedTensorRequest) returns (GetAggregatedTensorResponse) {} + rpc GetAggregatedTensors(GetAggregatedTensorsRequest) returns (GetAggregatedTensorsResponse) {} rpc SendLocalTaskResults(stream DataStream) returns (SendLocalTaskResultsResponse) {} rpc InteropRelay(InteropMessage) returns (InteropMessage) {} } @@ -42,6 +42,14 @@ message Task { bool apply_local = 4; } +message TensorSpec { + string tensor_name = 1; + int32 round_number = 2; + bool report = 3; + repeated string tags = 4; + bool require_lossless = 5; +} + message GetTasksResponse { MessageHeader header = 1; int32 round_number = 2; @@ -51,20 +59,14 @@ message GetTasksResponse { bool quit = 5; // these three are exclusive } -message GetAggregatedTensorRequest { +message GetAggregatedTensorsRequest { MessageHeader header = 1; - string tensor_name = 2; - int32 round_number = 3; - bool report = 4; - repeated string tags = 5; - bool require_lossless = 6; + repeated TensorSpec tensor_specs = 2; } -// we'll actually send this as a data stream -message GetAggregatedTensorResponse { +message GetAggregatedTensorsResponse { MessageHeader header = 1; - int32 round_number = 2; - NamedTensor tensor = 3; + repeated NamedTensor tensors = 2; } // we'll actually send this as a data stream diff --git a/openfl/protocols/utils.py b/openfl/protocols/utils.py index d19cd594f0..5fe1eda62b 100644 --- a/openfl/protocols/utils.py +++ b/openfl/protocols/utils.py @@ -352,3 +352,77 @@ def get_headers(context) -> dict: values are the corresponding header values. """ return {header[0]: header[1] for header in context.invocation_metadata()} + + +def serialize_tensor(tensor_key, nparray, tensor_codec, lossless=True): + """Serialize the tensor. + + This function also performs compression. + + Args: + tensor_key (namedtuple): A TensorKey. + nparray: A NumPy array associated with the requested + tensor key. + tensor_codec: The codec to use for compression. + lossless: A flag indicating whether to use lossless compression. + + Returns: + named_tensor (protobuf) : The tensor constructed from the nparray. + """ + tensor_key, nparray, metadata = tensor_codec.compress( + tensor_key, + nparray, + lossless, + ) + named_tensor = construct_named_tensor( + tensor_key, + nparray, + metadata, + lossless, + ) + return named_tensor + + +def deserialize_tensor(named_tensor, tensor_codec): + """Deserialize a `NamedTensor` to a numpy array. + + This function also performs decompresssion. Whether or not the + decompression is lossless is determined by the `lossless` field + of the `NamedTensor` protobuf. + + Args: + named_tensor (protobuf): The tensor to convert to nparray. + tensor_codec: The codec to use for decompression. + + Returns: + A tuple (TensorKey, nparray), where the `origin` field of the + `TensorKey` is `None`. The `origin` field must be populated + later, as it is not known at this point. 
+ """ + metadata = [ + { + "int_to_float": proto.int_to_float, + "int_list": proto.int_list, + "bool_list": proto.bool_list, + } + for proto in named_tensor.transformer_metadata + ] + + # Deserialization happens on the receiving end. + # Origin of this tensor is populated later. + tensor_key = TensorKey( + named_tensor.name, + None, + named_tensor.round_number, + named_tensor.report, + tuple(named_tensor.tags), + ) + + tensor_key, nparray = tensor_codec.decompress( + tensor_key, + data=named_tensor.data_bytes, + transformer_metadata=metadata, + require_lossless=named_tensor.lossless, + ) + + return tensor_key, nparray diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index a91110b518..cb4a1b18f5 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -6,11 +6,11 @@ import logging import time -from typing import Optional, Tuple +from typing import List, Optional, Tuple import grpc -from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils +from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, base_pb2, utils from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel logger = logging.getLogger(__name__) @@ -341,27 +341,20 @@ def get_tasks(self): @_resend_data_on_reconnection @_atomic_connection - def get_aggregated_tensor( + def get_aggregated_tensors( self, - tensor_name, - round_number, - report, - tags, - require_lossless, - ): + tensor_keys, + require_lossless: bool = True, + ) -> List[base_pb2.NamedTensor]: """ - Get aggregated tensor from the aggregator. + Get aggregated tensors from the aggregator. Args: - collaborator_name (str): The name of the collaborator. - tensor_name (str): The name of the tensor. - round_number (int): The round number. - report (str): The report. - tags (List[str]): The tags. + tensor_keys (list): A list of tensor keys to fetch from aggregator. require_lossless (bool): Whether lossless compression is required. Returns: - aggregator_pb2.TensorProto: The aggregated tensor. + A list of `NamedTensor`s in the same order as requested. """ header = create_header( sender=self.collaborator_name, @@ -370,17 +363,24 @@ def get_aggregated_tensor( single_col_cert_common_name=self.single_col_cert_common_name, ) - request = aggregator_pb2.GetAggregatedTensorRequest( + request = aggregator_pb2.GetAggregatedTensorsRequest( header=header, - tensor_name=tensor_name, - round_number=round_number, - report=report, - tags=tags, - require_lossless=require_lossless, + tensor_specs=[ + aggregator_pb2.TensorSpec( + tensor_name=k.tensor_name, + round_number=k.round_number, + report=k.report, + tags=k.tags, + require_lossless=require_lossless, + ) + for k in tensor_keys + ], ) - response = self.stub.GetAggregatedTensor(request) + + response = self.stub.GetAggregatedTensors(request) self.validate_response(response) - return response.tensor + named_tensors = response.tensors + return named_tensors @_resend_data_on_reconnection @_atomic_connection diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index 0d2f82ac98..ebe89f0ca6 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -212,20 +212,17 @@ def GetTasks(self, request, context): # NOQA:N802 quit=time_to_quit, ) - def GetAggregatedTensor(self, request, context): # NOQA:N802 - """Request a job from aggregator. 
- - This method handles a request from a collaborator for an aggregated - tensor. + def GetAggregatedTensors(self, request, context): + """Request aggregated tensors from the aggregator. Args: - request (aggregator_pb2.GetAggregatedTensorRequest): The request - from the collaborator. + request (aggregator_pb2.GetAggregatedTensorsRequest): The request + from the collaborator comprising a list of TensorSpec objects. context (grpc.ServicerContext): The context of the request. Returns: - aggregator_pb2.GetAggregatedTensorResponse: The response to the - request. + aggregator_pb2.GetAggregatedTensorsResponse: The response to the + request, containing the aggregated tensors as list of `NamedTensor`s. """ if self.interop_mode: context.abort( @@ -236,13 +233,17 @@ def GetAggregatedTensor(self, request, context): # NOQA:N802 self.validate_collaborator(request, context) self.check_request(request) - named_tensor = self.aggregator.get_aggregated_tensor( - request.tensor_name, - request.round_number, - request.report, - tuple(request.tags), - request.require_lossless, - request.header.sender, + # Parse. + named_tensors = ( + self.aggregator.get_aggregated_tensor( + ts.tensor_name, + ts.round_number, + ts.report, + tuple(ts.tags), + ts.require_lossless, + request.header.sender, + ) + for ts in request.tensor_specs ) header = create_header( @@ -251,19 +252,11 @@ def GetAggregatedTensor(self, request, context): # NOQA:N802 federation_uuid=self.aggregator.federation_uuid, single_col_cert_common_name=self.aggregator.single_col_cert_common_name, ) - - return aggregator_pb2.GetAggregatedTensorResponse( - header=header, - round_number=request.round_number, - tensor=named_tensor, - ) + return aggregator_pb2.GetAggregatedTensorsResponse(header=header, tensors=named_tensors) @synchronized def SendLocalTaskResults(self, request, context): # NOQA:N802 - """Request a model download from aggregator. - - This method handles a request from a collaborator to send the results - of a local task. + """Send task results to the aggregator. Args: request (aggregator_pb2.SendLocalTaskResultsRequest): The request @@ -283,7 +276,6 @@ def SendLocalTaskResults(self, request, context): # NOQA:N802 ) self.validate_collaborator(proto, context) - # all messages get sanity checked self.check_request(proto) collaborator_name = proto.header.sender diff --git a/tests/end_to_end/test_suites/task_runner_tests.py b/tests/end_to_end/test_suites/task_runner_tests.py index 6098e2aee6..83b51f7ca5 100644 --- a/tests/end_to_end/test_suites/task_runner_tests.py +++ b/tests/end_to_end/test_suites/task_runner_tests.py @@ -72,4 +72,4 @@ def test_federation_connectivity(request, fx_federation_tr): # Verify collaborator able to ping aggregator for col in fx_federation_tr.collaborators: - assert fed_helper.ping_from_collaborator(col), f"Ping failed from {col.name} to aggregator" \ No newline at end of file + assert fed_helper.ping_from_collaborator(col), f"Ping failed from {col.name} to aggregator" diff --git a/tests/end_to_end/utils/constants.py b/tests/end_to_end/utils/constants.py index 79f10476ab..c3e63de45f 100644 --- a/tests/end_to_end/utils/constants.py +++ b/tests/end_to_end/utils/constants.py @@ -59,4 +59,4 @@ class ModelName(Enum): COL_CERTIFY_CMD = "fx collaborator certify --import 'agg_to_col_{}_signed_cert.zip'" EXCEPTION = "Exception" AGG_METRIC_MODEL_ACCURACY_KEY = "aggregator/aggregated_model_validation/accuracy" -COL_TLS_END_MSG = "TLS connection established." \ No newline at end of file +COL_TLS_END_MSG = "TLS connection established." 
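The batched `GetAggregatedTensors` RPC introduced above replaces the old per-tensor `GetAggregatedTensor` call, so a collaborator can fetch all missing tensors for a task in a single request/response pair. Below is a minimal sketch of how a caller might build the batched request from `TensorKey`s, mirroring the new client code in this diff; `header` and `stub` are assumed to already exist (e.g. a `MessageHeader` from `create_header` and the stub owned by `AggregatorGRPCClient`):

```python
from openfl.protocols import aggregator_pb2
from openfl.utilities import TensorKey


def build_get_aggregated_tensors_request(header, tensor_keys, require_lossless=True):
    """Turn a list of TensorKeys into a single batched request proto."""
    return aggregator_pb2.GetAggregatedTensorsRequest(
        header=header,
        tensor_specs=[
            aggregator_pb2.TensorSpec(
                tensor_name=k.tensor_name,
                round_number=k.round_number,
                report=k.report,
                tags=k.tags,
                require_lossless=require_lossless,
            )
            for k in tensor_keys
        ],
    )


# Usage sketch (names are illustrative): one round trip fetches every tensor.
# keys = [TensorKey("conv1.weight", aggregator_uuid, 3, False, ("model",)), ...]
# response = stub.GetAggregatedTensors(build_get_aggregated_tensors_request(header, keys))
# named_tensors = response.tensors  # returned in the same order as the specs
```

Since each request/response pair runs on its own TLS connection, batching also cuts the number of handshakes per round compared with fetching tensors one at a time.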
diff --git a/tests/end_to_end/utils/federation_helper.py b/tests/end_to_end/utils/federation_helper.py index 27cdf9a326..99d1a6c6b5 100644 --- a/tests/end_to_end/utils/federation_helper.py +++ b/tests/end_to_end/utils/federation_helper.py @@ -1202,4 +1202,4 @@ def ping_from_collaborator(collaborator): log.info(f"Aggregator is not reachable from {collaborator.name}. Retrying in 5 seconds...") time.sleep(5) log.error(f"Aggregator is not reachable from {collaborator.name}") - return False \ No newline at end of file + return False diff --git a/tests/openfl/component/collaborator/test_collaborator.py b/tests/openfl/component/collaborator/test_collaborator.py index 49288c2ffe..c5b2a22d0b 100644 --- a/tests/openfl/component/collaborator/test_collaborator.py +++ b/tests/openfl/component/collaborator/test_collaborator.py @@ -97,15 +97,13 @@ def test_send_task_results(collaborator_mock, tensor_key): """Test that send_task_results works correctly.""" task_name = 'task_name' tensor_key = tensor_key._replace(report=True) - tensor_dict = {tensor_key: 0} + tensor_dict = {} round_number = 0 data_size = -1 - collaborator_mock.nparray_to_named_tensor = mock.Mock(return_value=None) collaborator_mock.client.send_local_task_results = mock.Mock() collaborator_mock.send_task_results(tensor_dict, round_number, task_name) - collaborator_mock.client.send_local_task_results.assert_called_with( - round_number, task_name, data_size, [None]) + round_number, task_name, data_size, []) def test_send_task_results_train(collaborator_mock): @@ -138,95 +136,15 @@ def test_send_task_results_valid(collaborator_mock): round_number, task_name, data_size, []) -def test_named_tensor_to_nparray_without_tags(collaborator_mock, named_tensor): - """Test that named_tensor_to_nparray works correctly for tensor without tags.""" - nparray = collaborator_mock.named_tensor_to_nparray(named_tensor) - - assert named_tensor.data_bytes == nparray - - -@pytest.mark.parametrize('tag', ['compressed', 'lossy_compressed']) -def test_named_tensor_to_nparray_compressed_tag(collaborator_mock, named_tensor, tag): - """Test that named_tensor_to_nparray works correctly for tensor with tags.""" - named_tensor.tags.append(tag) - nparray = collaborator_mock.named_tensor_to_nparray(named_tensor) - - assert isinstance(nparray, numpy.ndarray) - - -def test_nparray_to_named_tensor(collaborator_mock, tensor_key, named_tensor): - """Test that nparray_to_named_tensor works correctly.""" - named_tensor.tags.append('compressed') - nparray = collaborator_mock.named_tensor_to_nparray(named_tensor) - tensor = collaborator_mock.nparray_to_named_tensor(tensor_key, nparray) - assert tensor.data_bytes == named_tensor.data_bytes - assert tensor.lossless is True - - -def test_nparray_to_named_tensor_trained(collaborator_mock, tensor_key_trained, named_tensor): - """Test that nparray_to_named_tensor works correctly for trained tensor.""" - named_tensor.tags.append('compressed') - collaborator_mock.use_delta_updates = True - nparray = collaborator_mock.named_tensor_to_nparray(named_tensor) - collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock( - return_value=nparray) - tensor = collaborator_mock.nparray_to_named_tensor(tensor_key_trained, nparray) - assert len(tensor.data_bytes) == 32 - assert tensor.lossless is False - assert 'delta' in tensor.tags - - -@pytest.mark.parametrize('require_lossless', [True, False]) -def test_get_aggregated_tensor_from_aggregator(collaborator_mock, tensor_key, - named_tensor, require_lossless): - """Test that get_aggregated_tensor works 
correctly.""" - collaborator_mock.client.get_aggregated_tensor = mock.Mock(return_value=named_tensor) - nparray = collaborator_mock.get_aggregated_tensor_from_aggregator(tensor_key, require_lossless) - - collaborator_mock.client.get_aggregated_tensor.assert_called_with( - tensor_key.tensor_name, tensor_key.round_number, - tensor_key.report, tensor_key.tags, require_lossless) - assert nparray == named_tensor.data_bytes - - -def test_get_data_for_tensorkey_from_db(collaborator_mock, tensor_key): - """Test that get_data_for_tensorkey works correctly for data form db.""" - expected_nparray = 'some_data' - collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock( - return_value='some_data') - nparray = collaborator_mock.get_data_for_tensorkey(tensor_key) - - assert nparray == expected_nparray - - -def test_get_data_for_tensorkey(collaborator_mock, tensor_key): - """Test that get_data_for_tensorkey works correctly if data is not in db.""" - collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock( - return_value=None) - collaborator_mock.get_aggregated_tensor_from_aggregator = mock.Mock() - collaborator_mock.get_data_for_tensorkey(tensor_key) - collaborator_mock.get_aggregated_tensor_from_aggregator.assert_called_with( - tensor_key, require_lossless=True) - - -def test_get_data_for_tensorkey_locally(collaborator_mock, tensor_key): - """Test that get_data_for_tensorkey works correctly if found tensor locally.""" - tensor_key = tensor_key._replace(round_number=1) - nparray = numpy.array([0, 1, 2, 3, 4]) - collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock( - side_effect=[None, nparray]) - ret = collaborator_mock.get_data_for_tensorkey(tensor_key) - - assert numpy.array_equal(ret, nparray) - - -def test_get_data_for_tensorkey_dependencies(collaborator_mock, tensor_key): - """Test that get_data_for_tensorkey works correctly if additional dependencies.""" - tensor_key = tensor_key._replace(round_number=1) - collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock( - return_value=None) - collaborator_mock.tensor_codec.find_dependencies = mock.Mock(return_value=[tensor_key]) - collaborator_mock.get_aggregated_tensor_from_aggregator = mock.Mock() - collaborator_mock.get_data_for_tensorkey(tensor_key) - collaborator_mock.get_aggregated_tensor_from_aggregator.assert_called_with( - tensor_key, require_lossless=True) +def test_fetch_tensors_from_aggregator(collaborator_mock, tensor_key, named_tensor): + """Test that fetch_tensors_from_aggregator works correctly.""" + # Simulate tensor not in cache + collaborator_mock.tensor_db.get_tensor_from_cache.return_value = None + collaborator_mock.client.get_aggregated_tensors = mock.Mock(return_value=[named_tensor]) + collaborator_mock.tensor_db.cache_tensor = mock.Mock() + # Patch utils.deserialize_tensor to avoid side effects + with mock.patch("openfl.protocols.utils.deserialize_tensor", return_value=(tensor_key, "nparray")): + collaborator_mock.fetch_tensors_from_aggregator([tensor_key]) + collaborator_mock.client.get_aggregated_tensors.assert_called_with( + [tensor_key], require_lossless=True) + collaborator_mock.tensor_db.cache_tensor.assert_called()