Skip to content

Commit 19c29e2

Browse files
Simplify aggregator with no deltas and simple compress/decompress
Signed-off-by: Shah, Karan <[email protected]>
1 parent c915b16 commit 19c29e2

File tree

1 file changed

+43
-210
lines changed

1 file changed

+43
-210
lines changed

openfl/component/aggregator/aggregator.py

Lines changed: 43 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -606,65 +606,46 @@ def get_aggregated_tensor(
606606
Raises:
607607
ValueError: if Aggregator does not have an aggregated tensor for {tensor_key}.
608608
"""
609-
if "compressed" in tags or require_lossless:
610-
compress_lossless = True
611-
else:
612-
compress_lossless = False
613-
614609
if not self._check_tags(tags, requested_by):
615610
logger.error(
616-
"Tag check failed: unauthorized tags detected. Only '%s' is allowed.", requested_by
611+
"Collaborator `%s` is not allowed to fetch tensor with tags `%s`.",
612+
requested_by,
613+
tags,
617614
)
618615
return NamedTensor()
619616

620-
# TODO the TensorDB doesn't support compressed data yet.
621-
# The returned tensor will be recompressed anyway.
617+
# We simply remove compression-related tags because the serializer adds them.
622618
if "compressed" in tags:
623619
tags = change_tags(tags, remove_field="compressed")
624620
if "lossy_compressed" in tags:
625621
tags = change_tags(tags, remove_field="lossy_compressed")
626622

623+
# Fetch tensor
627624
tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, tags)
628-
tensor_name, origin, round_number, report, tags = tensor_key
629-
630-
if "aggregated" in tags and "delta" in tags and round_number != 0:
631-
agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("aggregated",))
632-
else:
633-
agg_tensor_key = tensor_key
634-
635-
nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key)
625+
nparray = self.tensor_db.get_tensor_from_cache(tensor_key)
636626
if nparray is None:
637627
raise ValueError(f"Aggregator does not have `{tensor_key}`")
638628

639-
# quite a bit happens in here, including compression, delta handling,
640-
# etc...
641-
# we might want to cache these as well
642-
named_tensor = self._nparray_to_named_tensor(
643-
agg_tensor_key, nparray, send_model_deltas=True, compress_lossless=compress_lossless
644-
)
645-
629+
# Serialize (and compress) the tensor
630+
named_tensor = self.serialize_tensor(tensor_key, nparray, lossless=require_lossless)
646631
return named_tensor
647632

648-
def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, compress_lossless):
649-
"""Construct the NamedTensor Protobuf.
633+
def serialize_tensor(self, tensor_key, nparray, lossless: bool):
634+
"""Serialize the tensor.
650635
651-
Also includes logic to create delta, compress tensors with the
652-
TensorCodec, etc.
636+
This function also performs compression.
653637
654638
Args:
655-
tensor_key (TensorKey): Tensor key.
656-
nparray (np.array): Numpy array.
657-
send_model_deltas (bool): Whether to send model deltas.
658-
compress_lossless (bool): Whether to compress lossless.
639+
tensor_key (namedtuple): A TensorKey.
640+
nparray: A NumPy array associated with the requested
641+
tensor key.
642+
lossless: Whether to use lossless compression.
659643
660644
Returns:
661-
tensor_key (TensorKey): Tensor key.
662-
nparray (np.array): Numpy array.
663-
645+
named_tensor (protobuf) : The tensor constructed from the nparray.
664646
"""
665-
tensor_name, origin, round_number, report, tags = tensor_key
666647
# Secure aggregation setup tensor.
667-
if "secagg" in tags:
648+
if "secagg" in tensor_key.tags:
668649
import numpy as np
669650

670651
class NumpyEncoder(json.JSONEncoder):
@@ -681,42 +662,14 @@ def default(self, obj):
681662
)
682663

683664
return named_tensor
684-
# if we have an aggregated tensor, we can make a delta
685-
if "aggregated" in tags and send_model_deltas:
686-
# Should get the pretrained model to create the delta. If training
687-
# has happened, Model should already be stored in the TensorDB
688-
model_tk = TensorKey(tensor_name, origin, round_number - 1, report, ("model",))
689-
690-
model_nparray = self.tensor_db.get_tensor_from_cache(model_tk)
691665

692-
assert model_nparray is not None, (
693-
"The original model layer should be present if the latest "
694-
"aggregated model is present"
695-
)
696-
delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta(
697-
tensor_key, nparray, model_nparray
698-
)
699-
delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress(
700-
delta_tensor_key, delta_nparray, lossless=compress_lossless
701-
)
702-
named_tensor = utils.construct_named_tensor(
703-
delta_comp_tensor_key,
704-
delta_comp_nparray,
705-
metadata,
706-
lossless=compress_lossless,
707-
)
708-
709-
else:
710-
# Assume every other tensor requires lossless compression
711-
compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress(
712-
tensor_key, nparray, require_lossless=True
713-
)
714-
named_tensor = utils.construct_named_tensor(
715-
compressed_tensor_key,
716-
compressed_nparray,
717-
metadata,
718-
lossless=compress_lossless,
719-
)
666+
tensor_key, nparray, metadata = self.tensor_codec.compress(tensor_key, nparray, lossless)
667+
named_tensor = utils.construct_named_tensor(
668+
tensor_key,
669+
nparray,
670+
metadata,
671+
lossless,
672+
)
720673

721674
return named_tensor
722675

@@ -837,9 +790,7 @@ def process_task_results(
837790
task_results = []
838791

839792
for named_tensor in named_tensors:
840-
# quite a bit happens in here, including decompression, delta
841-
# handling, etc...
842-
tensor_key, value = self._process_named_tensor(named_tensor, collaborator_name)
793+
tensor_key, value = self.deserialize_tensor(named_tensor, collaborator_name)
843794

844795
if "metric" in tensor_key.tags:
845796
# Caution: This schema must be followed. It is also used in
@@ -881,26 +832,17 @@ def _end_of_round_with_stragglers_check(self):
881832
logger.warning(f"Identified stragglers: {self.stragglers}")
882833
self._end_of_round_check()
883834

884-
def _process_named_tensor(self, named_tensor, collaborator_name):
885-
"""Extract the named tensor fields.
835+
def deserialize_tensor(self, named_tensor, collaborator_name):
836+
"""Deserialize a `NamedTensor` to a numpy array.
886837
887-
Performs decompression, delta computation, and inserts results into
888-
TensorDB.
838+
This function also performs decompression.
889839
890840
Args:
891-
named_tensor (protobuf NamedTensor): Named tensor.
892-
protobuf that will be extracted from and processed
893-
collaborator_name (str): Collaborator name.
894-
Collaborator name is needed for proper tagging of resulting
895-
tensorkeys.
841+
named_tensor (protobuf): The tensor to convert to nparray.
896842
897843
Returns:
898-
tensor_key (TensorKey): Tensor key.
899-
The tensorkey extracted from the protobuf.
900-
nparray (np.array): Numpy array.
901-
The numpy array associated with the returned tensorkey.
844+
A tuple (TensorKey, nparray).
902845
"""
903-
raw_bytes = named_tensor.data_bytes
904846
metadata = [
905847
{
906848
"int_to_float": proto.int_to_float,
@@ -918,59 +860,19 @@ def _process_named_tensor(self, named_tensor, collaborator_name):
918860
named_tensor.report,
919861
tuple(named_tensor.tags),
920862
)
921-
tensor_name, origin, round_number, report, tags = tensor_key
922863

923-
assert "compressed" in tags or "lossy_compressed" in tags, (
924-
f"Named tensor {tensor_key} is not compressed"
864+
tensor_key, nparray = self.tensor_codec.decompress(
865+
tensor_key,
866+
data=named_tensor.data_bytes,
867+
transformer_metadata=metadata,
868+
require_lossless=named_tensor.lossless,
925869
)
926-
if "compressed" in tags:
927-
dec_tk, decompressed_nparray = self.tensor_codec.decompress(
928-
tensor_key,
929-
data=raw_bytes,
930-
transformer_metadata=metadata,
931-
require_lossless=True,
932-
)
933-
dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk
934-
# Need to add the collaborator tag to the resulting tensor
935-
new_tags = change_tags(dec_tags, add_field=collaborator_name)
936-
937-
# layer.agg.n.trained.delta.col_i
938-
decompressed_tensor_key = TensorKey(
939-
dec_name, dec_origin, dec_round_num, dec_report, new_tags
940-
)
941-
if "lossy_compressed" in tags:
942-
dec_tk, decompressed_nparray = self.tensor_codec.decompress(
943-
tensor_key,
944-
data=raw_bytes,
945-
transformer_metadata=metadata,
946-
require_lossless=False,
947-
)
948-
dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk
949-
new_tags = change_tags(dec_tags, add_field=collaborator_name)
950-
# layer.agg.n.trained.delta.lossy_decompressed.col_i
951-
decompressed_tensor_key = TensorKey(
952-
dec_name, dec_origin, dec_round_num, dec_report, new_tags
953-
)
954-
955-
if "delta" in tags:
956-
base_model_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("model",))
957-
base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tensor_key)
958-
if base_model_nparray is None:
959-
raise ValueError(f"Base model {base_model_tensor_key} not present in TensorDB")
960-
final_tensor_key, final_nparray = self.tensor_codec.apply_delta(
961-
decompressed_tensor_key,
962-
decompressed_nparray,
963-
base_model_nparray,
964-
)
965-
else:
966-
final_tensor_key = decompressed_tensor_key
967-
final_nparray = decompressed_nparray
870+
updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name)
871+
tensor_key = tensor_key._replace(tags=updated_tags)
968872

969-
assert final_nparray is not None, f"Could not create tensorkey {final_tensor_key}"
970-
self.tensor_db.cache_tensor({final_tensor_key: final_nparray})
971-
logger.debug("Created TensorKey: %s", final_tensor_key)
873+
self.tensor_db.cache_tensor({tensor_key: nparray})
972874

973-
return final_tensor_key, final_nparray
875+
return tensor_key, nparray
974876

975877
def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
976878
"""Prepare aggregated tensorkey tags.
@@ -982,82 +884,13 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
982884
report (bool): Whether to report.
983885
agg_results (np.array): Aggregated results.
984886
"""
985-
# The aggregated tensorkey tags should have the form of
986-
# 'trained' or 'trained.lossy_decompressed'
987-
# They need to be relabeled to 'aggregated' and
988-
# reinserted. Then delta performed, compressed, etc.
989-
# then reinserted to TensorDB with 'model' tag
990-
991-
# First insert the aggregated model layer with the
992-
# correct tensorkey
993887
agg_tag_tk = TensorKey(tensor_name, origin, round_number + 1, report, ("aggregated",))
994888
self.tensor_db.cache_tensor({agg_tag_tk: agg_results})
995889

996-
# Create delta and save it in TensorDB
997-
base_model_tk = TensorKey(tensor_name, origin, round_number, report, ("model",))
998-
base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk)
999-
if base_model_nparray is not None and self.use_delta_updates:
1000-
delta_tk, delta_nparray = self.tensor_codec.generate_delta(
1001-
agg_tag_tk, agg_results, base_model_nparray
1002-
)
1003-
else:
1004-
# This condition is possible for base model
1005-
# optimizer states (i.e. Adam/iter:0, SGD, etc.)
1006-
# These values couldn't be present for the base
1007-
# model because no training occurs on the aggregator
1008-
delta_tk, delta_nparray = agg_tag_tk, agg_results
1009-
1010-
# Compress lossless/lossy
1011-
compressed_delta_tk, compressed_delta_nparray, metadata = self.tensor_codec.compress(
1012-
delta_tk, delta_nparray
1013-
)
1014-
1015-
# TODO extend the TensorDB so that compressed data is
1016-
# supported. Once that is in place
1017-
# the compressed delta can just be stored here instead
1018-
# of recreating it for every request
1019-
1020-
# Decompress lossless/lossy
1021-
decompressed_delta_tk, decompressed_delta_nparray = self.tensor_codec.decompress(
1022-
compressed_delta_tk, compressed_delta_nparray, metadata
1023-
)
1024-
1025-
self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray})
1026-
1027-
# Apply delta (unless delta couldn't be created)
1028-
if base_model_nparray is not None and self.use_delta_updates:
1029-
logger.debug("Applying delta for layer %s", decompressed_delta_tk[0])
1030-
new_model_tk, new_model_nparray = self.tensor_codec.apply_delta(
1031-
decompressed_delta_tk,
1032-
decompressed_delta_nparray,
1033-
base_model_nparray,
1034-
)
1035-
else:
1036-
new_model_tk, new_model_nparray = (
1037-
decompressed_delta_tk,
1038-
decompressed_delta_nparray,
1039-
)
1040-
1041-
# Now that the model has been compressed/decompressed
1042-
# with delta operations,
1043-
# Relabel the tags to 'model'
1044-
(
1045-
new_model_tensor_name,
1046-
new_model_origin,
1047-
new_model_round_number,
1048-
new_model_report,
1049-
new_model_tags,
1050-
) = new_model_tk
1051-
final_model_tk = TensorKey(
1052-
new_model_tensor_name,
1053-
new_model_origin,
1054-
new_model_round_number,
1055-
new_model_report,
1056-
("model",),
1057-
)
1058-
self.next_model_round_number = new_model_round_number
1059-
# Finally, cache the updated model tensor
1060-
self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})
890+
# Relabel the tags to 'model' and cache the updated model tensor
891+
final_model_tk = agg_tag_tk._replace(tags=("model",))
892+
self.next_model_round_number = final_model_tk.round_number
893+
self.tensor_db.cache_tensor({final_model_tk: agg_results})
1061894

1062895
def _compute_validation_related_task_metrics(self, task_name) -> dict:
1063896
"""Compute all validation related metrics.

0 commit comments

Comments
 (0)