Skip to content

Commit 3c528ab

Browse files
Use common serialize/deserialize functions for both components
Signed-off-by: Shah, Karan <[email protected]>
1 parent 9c19cc3 commit 3c528ab

File tree

3 files changed

+107
-152
lines changed

3 files changed

+107
-152
lines changed

openfl/component/aggregator/aggregator.py

Lines changed: 24 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -627,24 +627,6 @@ def get_aggregated_tensor(
627627
raise ValueError(f"Aggregator does not have `{tensor_key}`")
628628

629629
# Serialize (and compress) the tensor
630-
named_tensor = self.serialize_tensor(tensor_key, nparray, lossless=require_lossless)
631-
return named_tensor
632-
633-
def serialize_tensor(self, tensor_key, nparray, lossless: bool):
634-
"""Serialize the tensor.
635-
636-
This function also performs compression.
637-
638-
Args:
639-
tensor_key (namedtuple): A TensorKey.
640-
nparray: A NumPy array associated with the requested
641-
tensor key.
642-
lossless: Whether to use lossless compression.
643-
644-
Returns:
645-
named_tensor (protobuf) : The tensor constructed from the nparray.
646-
"""
647-
# Secure aggregation setup tensor.
648630
if "secagg" in tensor_key.tags:
649631
import numpy as np
650632

@@ -663,14 +645,9 @@ def default(self, obj):
663645

664646
return named_tensor
665647

666-
tensor_key, nparray, metadata = self.tensor_codec.compress(tensor_key, nparray, lossless)
667-
named_tensor = utils.construct_named_tensor(
668-
tensor_key,
669-
nparray,
670-
metadata,
671-
lossless,
648+
named_tensor = utils.serialize_tensor(
649+
tensor_key, nparray, self.tensor_codec, lossless=require_lossless
672650
)
673-
674651
return named_tensor
675652

676653
def _collaborator_task_completed(self, collaborator, task_name, round_num):
@@ -769,43 +746,47 @@ def process_task_results(
769746
self._is_collaborator_done(collaborator_name, round_number)
770747
self._end_of_round_with_stragglers_check()
771748

772-
task_key = TaskResultKey(task_name, collaborator_name, round_number)
773-
774-
# we mustn't have results already
775749
if self._collaborator_task_completed(collaborator_name, task_name, round_number):
776750
logger.warning(
777-
f"Aggregator already has task results from collaborator {collaborator_name}"
778-
f" for task {task_key}"
751+
f"Aggregator already has task results from collaborator {collaborator_name} "
752+
f"for task {task_name} in round {round_number}. Ignoring..."
779753
)
780754
return
781755

782-
# By giving task_key its own weight, we can support different
783-
# training/validation weights
784-
# As well as eventually supporting weights that change by round
785-
# (if more data is added)
756+
# Record collaborator individual weightage/contribution for federated averaging
757+
task_key = TaskResultKey(task_name, collaborator_name, round_number)
786758
self.collaborator_task_weight[task_key] = data_size
787759

788-
# initialize the list of tensors that go with this task
789-
# Setting these incrementally is leading to missing values
760+
# Process named tensors
790761
task_results = []
791-
762+
result_tensor_dict = {}
792763
for named_tensor in named_tensors:
793-
tensor_key, value = self.deserialize_tensor(named_tensor, collaborator_name)
764+
# Deserialize
765+
tensor_key, nparray = utils.deserialize_tensor(named_tensor, self.tensor_codec)
766+
767+
# Update origin/tags
768+
updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name)
769+
tensor_key = tensor_key._replace(origin=self.uuid, tags=updated_tags)
770+
771+
# Record
772+
result_tensor_dict[tensor_key] = nparray
773+
task_results.append(tensor_key)
794774

795775
if "metric" in tensor_key.tags:
796-
# Caution: This schema must be followed. It is also used in
797-
# gRPC message streams for director/envoy.
776+
assert nparray.ndim == 0, (
777+
f"Expected metric to be a scalar, got shape {nparray.shape}"
778+
)
798779
metrics = {
799780
"round": round_number,
800781
"metric_origin": collaborator_name,
801782
"task_name": task_name,
802783
"metric_name": tensor_key.tensor_name,
803-
"metric_value": float(value),
784+
"metric_value": float(nparray),
804785
}
805786
self.metric_queue.put(metrics)
806787

807-
task_results.append(tensor_key)
808-
788+
# Store results in TensorDB
789+
self.tensor_db.cache_tensor(result_tensor_dict)
809790
self.collaborator_tasks_results[task_key] = task_results
810791

811792
# Check if collaborator or round is done.
@@ -832,48 +813,6 @@ def _end_of_round_with_stragglers_check(self):
832813
logger.warning(f"Identified stragglers: {self.stragglers}")
833814
self._end_of_round_check()
834815

835-
def deserialize_tensor(self, named_tensor, collaborator_name):
836-
"""Deserialize a `NamedTensor` to a numpy array.
837-
838-
This function also performs decompression.
839-
840-
Args:
841-
named_tensor (protobuf): The tensor to convert to nparray.
842-
843-
Returns:
844-
A tuple (TensorKey, nparray).
845-
"""
846-
metadata = [
847-
{
848-
"int_to_float": proto.int_to_float,
849-
"int_list": proto.int_list,
850-
"bool_list": proto.bool_list,
851-
}
852-
for proto in named_tensor.transformer_metadata
853-
]
854-
# The tensor has already been transferred to aggregator,
855-
# so the newly constructed tensor should have the aggregator origin
856-
tensor_key = TensorKey(
857-
named_tensor.name,
858-
self.uuid,
859-
named_tensor.round_number,
860-
named_tensor.report,
861-
tuple(named_tensor.tags),
862-
)
863-
864-
tensor_key, nparray = self.tensor_codec.decompress(
865-
tensor_key,
866-
data=named_tensor.data_bytes,
867-
transformer_metadata=metadata,
868-
require_lossless=named_tensor.lossless,
869-
)
870-
updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name)
871-
tensor_key = tensor_key._replace(tags=updated_tags)
872-
873-
self.tensor_db.cache_tensor({tensor_key: nparray})
874-
875-
return tensor_key, nparray
876-
877816
def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
878817
"""Prepare aggregated tensorkey tags.
879818

openfl/component/collaborator/collaborator.py

Lines changed: 9 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,12 @@ def fetch_tensors_from_aggregator(self, tensor_keys: List[TensorKey]):
283283
if len(tensor_keys) > 0:
284284
logger.info("Fetching %d tensors from the aggregator", len(tensor_keys))
285285
named_tensors = self.client.get_aggregated_tensors(tensor_keys, require_lossless=True)
286+
287+
# Deserialize tensors and mark them as coming from the aggregator.
286288
for tensor_key, named_tensor in zip(tensor_keys, named_tensors):
287-
tensor_dict[tensor_key] = self.deserialize_tensor(named_tensor)
289+
tensor_key, nparray = utils.deserialize_tensor(named_tensor, self.tensor_codec)
290+
tensor_key = tensor_key._replace(origin=self.aggregator_uuid)
291+
tensor_dict[tensor_key] = nparray
288292

289293
self.tensor_db.cache_tensor(tensor_dict)
290294

@@ -322,7 +326,10 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
322326
metrics.update({f"{self.collaborator_name}/{task_name}/{tensor_name}": value})
323327

324328
# Serialize tensors to be sent to the aggregator
325-
named_tensors = [self.serialize_tensor(k, v) for k, v in tensor_dict.items()]
329+
named_tensors = [
330+
utils.serialize_tensor(k, v, self.tensor_codec, lossless=True)
331+
for k, v in tensor_dict.items()
332+
]
326333

327334
self.client.send_local_task_results(
328335
round_number,
@@ -333,71 +340,6 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
333340

334341
return metrics
335342

336-
def serialize_tensor(self, tensor_key, nparray):
337-
"""Serialize the tensor.
338-
339-
This function also performs compression.
340-
341-
Args:
342-
tensor_key (namedtuple): A TensorKey.
343-
nparray: A NumPy array associated with the requested
344-
tensor key.
345-
346-
Returns:
347-
named_tensor (protobuf) : The tensor constructed from the nparray.
348-
"""
349-
lossless = True
350-
tensor_key, nparray, metadata = self.tensor_codec.compress(
351-
tensor_key,
352-
nparray,
353-
lossless,
354-
)
355-
named_tensor = utils.construct_named_tensor(
356-
tensor_key,
357-
nparray,
358-
metadata,
359-
lossless,
360-
)
361-
return named_tensor
362-
363-
def deserialize_tensor(self, named_tensor):
364-
"""Deserialize a `NamedTensor` to a numpy array.
365-
366-
This function also performs decompression.
367-
368-
Args:
369-
named_tensor (protobuf): The tensor to convert to nparray.
370-
371-
Returns:
372-
The converted nparray.
373-
"""
374-
metadata = [
375-
{
376-
"int_to_float": proto.int_to_float,
377-
"int_list": proto.int_list,
378-
"bool_list": proto.bool_list,
379-
}
380-
for proto in named_tensor.transformer_metadata
381-
]
382-
# The tensor has already been transferred to collaborator, so
383-
# the newly constructed tensor should have the collaborator origin
384-
tensor_key = TensorKey(
385-
named_tensor.name,
386-
self.collaborator_name,
387-
named_tensor.round_number,
388-
named_tensor.report,
389-
tuple(named_tensor.tags),
390-
)
391-
392-
tensor_key, nparray = self.tensor_codec.decompress(
393-
tensor_key,
394-
data=named_tensor.data_bytes,
395-
transformer_metadata=metadata,
396-
require_lossless=named_tensor.lossless,
397-
)
398-
399-
return nparray
400-
401343
def _apply_masks(
402344
self,
403345
tensor_dict,

openfl/protocols/utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,77 @@ def get_headers(context) -> dict:
352352
values are the corresponding header values.
353353
"""
354354
return {header[0]: header[1] for header in context.invocation_metadata()}
355+
356+
357+
def serialize_tensor(tensor_key, nparray, tensor_codec, lossless=True):
358+
"""Serialize the tensor.
359+
360+
This function also performs compression.
361+
362+
Args:
363+
tensor_key (namedtuple): A TensorKey.
364+
nparray: A NumPy array associated with the requested
365+
tensor key.
366+
tensor_codec: The codec to use for compression.
367+
lossless: A flag indicating whether to use lossless compression.
368+
369+
Returns:
370+
named_tensor (protobuf) : The tensor constructed from the nparray.
371+
"""
372+
tensor_key, nparray, metadata = tensor_codec.compress(
373+
tensor_key,
374+
nparray,
375+
lossless,
376+
)
377+
named_tensor = construct_named_tensor(
378+
tensor_key,
379+
nparray,
380+
metadata,
381+
lossless,
382+
)
383+
return named_tensor
384+
385+
386+
def deserialize_tensor(named_tensor, tensor_codec):
387+
"""Deserialize a `NamedTensor` to a numpy array.
388+
389+
This function also performs decompression. Whether or not the
390+
decompression is lossless is determined by the `lossless` field
391+
of the `NamedTensor` protobuf.
392+
393+
Args:
394+
named_tensor (protobuf): The tensor to convert to nparray.
395+
tensor_codec: The codec to use for decompression.
396+
397+
Returns:
398+
A tuple (TensorKey, nparray), where the `origin` field of the
399+
`TensorKey` is `None`. The `origin` field must be populated
400+
later, as it is not known at this point.
401+
"""
402+
metadata = [
403+
{
404+
"int_to_float": proto.int_to_float,
405+
"int_list": proto.int_list,
406+
"bool_list": proto.bool_list,
407+
}
408+
for proto in named_tensor.transformer_metadata
409+
]
410+
411+
# Deserialization happens on the receiving end.
412+
# Origin of this tensor is populated later.
413+
tensor_key = TensorKey(
414+
named_tensor.name,
415+
None,
416+
named_tensor.round_number,
417+
named_tensor.report,
418+
tuple(named_tensor.tags),
419+
)
420+
421+
tensor_key, nparray = tensor_codec.decompress(
422+
tensor_key,
423+
data=named_tensor.data_bytes,
424+
transformer_metadata=metadata,
425+
require_lossless=named_tensor.lossless,
426+
)
427+
428+
return tensor_key, nparray

0 commit comments

Comments
 (0)