@@ -606,65 +606,46 @@ def get_aggregated_tensor(
        Raises:
            ValueError: if Aggregator does not have an aggregated tensor for {tensor_key}.
        """
-        if "compressed" in tags or require_lossless:
-            compress_lossless = True
-        else:
-            compress_lossless = False
-
        if not self._check_tags(tags, requested_by):
            logger.error(
-                "Tag check failed: unauthorized tags detected. Only '%s' is allowed.", requested_by
+                "Collaborator `%s` is not allowed to fetch tensor with tags `%s`.",
+                requested_by,
+                tags,
            )
            return NamedTensor()

-        # TODO the TensorDB doesn't support compressed data yet.
-        # The returned tensor will be recompressed anyway.
+        # We simply remove compression-related tags because the serializer adds them.
        if "compressed" in tags:
            tags = change_tags(tags, remove_field="compressed")
        if "lossy_compressed" in tags:
            tags = change_tags(tags, remove_field="lossy_compressed")

+        # Fetch tensor
        tensor_key = TensorKey(tensor_name, self.uuid, round_number, report, tags)
-        tensor_name, origin, round_number, report, tags = tensor_key
-
-        if "aggregated" in tags and "delta" in tags and round_number != 0:
-            agg_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("aggregated",))
-        else:
-            agg_tensor_key = tensor_key
-
-        nparray = self.tensor_db.get_tensor_from_cache(agg_tensor_key)
+        nparray = self.tensor_db.get_tensor_from_cache(tensor_key)

        if nparray is None:
            raise ValueError(f"Aggregator does not have `{tensor_key}`")

-        # quite a bit happens in here, including compression, delta handling,
-        # etc...
-        # we might want to cache these as well
-        named_tensor = self._nparray_to_named_tensor(
-            agg_tensor_key, nparray, send_model_deltas=True, compress_lossless=compress_lossless
-        )
-
+        # Serialize (and compress) the tensor
+        named_tensor = self.serialize_tensor(tensor_key, nparray, lossless=require_lossless)

        return named_tensor

-    def _nparray_to_named_tensor(self, tensor_key, nparray, send_model_deltas, compress_lossless):
-        """Construct the NamedTensor Protobuf.
+    def serialize_tensor(self, tensor_key, nparray, lossless: bool):
+        """Serialize the tensor.

-        Also includes logic to create delta, compress tensors with the
-        TensorCodec, etc.
+        This function also performs compression.

        Args:
-            tensor_key (TensorKey): Tensor key.
-            nparray (np.array): Numpy array.
-            send_model_deltas (bool): Whether to send model deltas.
-            compress_lossless (bool): Whether to compress lossless.
+            tensor_key (namedtuple): A TensorKey.
+            nparray: A NumPy array associated with the requested
+                tensor key.
+            lossless: Whether to use lossless compression.

        Returns:
-            tensor_key (TensorKey): Tensor key.
-            nparray (np.array): Numpy array.
-
+            named_tensor (protobuf): The tensor constructed from the nparray.
        """
-        tensor_name, origin, round_number, report, tags = tensor_key
        # Secure aggregation setup tensor.
-        if "secagg" in tags:
+        if "secagg" in tensor_key.tags:
            import numpy as np

            class NumpyEncoder(json.JSONEncoder):
@@ -681,42 +662,14 @@ def default(self, obj):
            )

            return named_tensor
-        # if we have an aggregated tensor, we can make a delta
-        if "aggregated" in tags and send_model_deltas:
-            # Should get the pretrained model to create the delta. If training
-            # has happened, Model should already be stored in the TensorDB
-            model_tk = TensorKey(tensor_name, origin, round_number - 1, report, ("model",))
-
-            model_nparray = self.tensor_db.get_tensor_from_cache(model_tk)
-
-            assert model_nparray is not None, (
-                "The original model layer should be present if the latest "
-                "aggregated model is present"
-            )
-            delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta(
-                tensor_key, nparray, model_nparray
-            )
-            delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress(
-                delta_tensor_key, delta_nparray, lossless=compress_lossless
-            )
-            named_tensor = utils.construct_named_tensor(
-                delta_comp_tensor_key,
-                delta_comp_nparray,
-                metadata,
-                lossless=compress_lossless,
-            )
-
-        else:
-            # Assume every other tensor requires lossless compression
-            compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress(
-                tensor_key, nparray, require_lossless=True
-            )
-            named_tensor = utils.construct_named_tensor(
-                compressed_tensor_key,
-                compressed_nparray,
-                metadata,
-                lossless=compress_lossless,
-            )
+        tensor_key, nparray, metadata = self.tensor_codec.compress(tensor_key, nparray, lossless)
+        named_tensor = utils.construct_named_tensor(
+            tensor_key,
+            nparray,
+            metadata,
+            lossless,
+        )

        return named_tensor
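
For orientation, here is a minimal, self-contained sketch of what the serialize path above conceptually produces. The real code returns a protobuf NamedTensor built by `utils.construct_named_tensor` from the TensorCodec's output; the dict and zlib calls below are only illustrative stand-ins, and the field names are assumptions based on what the receiving side of this diff reads back (`data_bytes`, `lossless`, `report`, `tags`, transformer metadata).

# Illustrative sketch only (not part of the diff). A plain dict stands in for the
# protobuf NamedTensor and zlib stands in for the TensorCodec's lossless pipeline.
import zlib
from collections import namedtuple

import numpy as np

TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])


def serialize_tensor_sketch(tensor_key, nparray, lossless):
    # Stand-in for tensor_codec.compress(): turn the array into bytes plus the
    # metadata needed to invert the transform on the collaborator side.
    raw = nparray.astype(np.float32).tobytes()
    data_bytes = zlib.compress(raw) if lossless else raw
    metadata = [{"shape": list(nparray.shape), "dtype": "float32"}]
    # Stand-in for utils.construct_named_tensor(): bundle key fields, payload,
    # metadata, and the lossless flag that deserialize_tensor() reads back.
    return {
        "tensor_name": tensor_key.tensor_name,
        "round_number": tensor_key.round_number,
        "report": tensor_key.report,
        "tags": tensor_key.tags,
        "lossless": lossless,
        "transformer_metadata": metadata,
        "data_bytes": data_bytes,
    }


tk = TensorKey("conv1.weight", "aggregator", 3, False, ("aggregated",))
named = serialize_tensor_sketch(tk, np.ones((2, 2)), lossless=True)
print(named["tags"], named["lossless"], len(named["data_bytes"]))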
@@ -837,9 +790,7 @@ def process_task_results(
        task_results = []

        for named_tensor in named_tensors:
-            # quite a bit happens in here, including decompression, delta
-            # handling, etc...
-            tensor_key, value = self._process_named_tensor(named_tensor, collaborator_name)
+            tensor_key, value = self.deserialize_tensor(named_tensor, collaborator_name)

            if "metric" in tensor_key.tags:
                # Caution: This schema must be followed. It is also used in
@@ -881,26 +832,17 @@ def _end_of_round_with_stragglers_check(self):
            logger.warning(f"Identified stragglers: {self.stragglers}")
        self._end_of_round_check()

-    def _process_named_tensor(self, named_tensor, collaborator_name):
-        """Extract the named tensor fields.
+    def deserialize_tensor(self, named_tensor, collaborator_name):
+        """Deserialize a `NamedTensor` to a numpy array.

-        Performs decompression, delta computation, and inserts results into
-        TensorDB.
+        This function also performs decompression.

        Args:
-            named_tensor (protobuf NamedTensor): Named tensor.
-                protobuf that will be extracted from and processed
-            collaborator_name (str): Collaborator name.
-                Collaborator name is needed for proper tagging of resulting
-                tensorkeys.
+            named_tensor (protobuf): The tensor to convert to nparray.
+            collaborator_name (str): Name of the sending collaborator, added to
+                the tags of the resulting TensorKey.

        Returns:
-            tensor_key (TensorKey): Tensor key.
-                The tensorkey extracted from the protobuf.
-            nparray (np.array): Numpy array.
-                The numpy array associated with the returned tensorkey.
+            A tuple (TensorKey, nparray).
        """
-        raw_bytes = named_tensor.data_bytes
        metadata = [
            {
                "int_to_float": proto.int_to_float,
@@ -918,59 +860,19 @@ def _process_named_tensor(self, named_tensor, collaborator_name):
            named_tensor.report,
            tuple(named_tensor.tags),
        )
-        tensor_name, origin, round_number, report, tags = tensor_key

-        assert "compressed" in tags or "lossy_compressed" in tags, (
-            f"Named tensor {tensor_key} is not compressed"
+        tensor_key, nparray = self.tensor_codec.decompress(
+            tensor_key,
+            data=named_tensor.data_bytes,
+            transformer_metadata=metadata,
+            require_lossless=named_tensor.lossless,
        )
-        if "compressed" in tags:
-            dec_tk, decompressed_nparray = self.tensor_codec.decompress(
-                tensor_key,
-                data=raw_bytes,
-                transformer_metadata=metadata,
-                require_lossless=True,
-            )
-            dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk
-            # Need to add the collaborator tag to the resulting tensor
-            new_tags = change_tags(dec_tags, add_field=collaborator_name)
-
-            # layer.agg.n.trained.delta.col_i
-            decompressed_tensor_key = TensorKey(
-                dec_name, dec_origin, dec_round_num, dec_report, new_tags
-            )
-        if "lossy_compressed" in tags:
-            dec_tk, decompressed_nparray = self.tensor_codec.decompress(
-                tensor_key,
-                data=raw_bytes,
-                transformer_metadata=metadata,
-                require_lossless=False,
-            )
-            dec_name, dec_origin, dec_round_num, dec_report, dec_tags = dec_tk
-            new_tags = change_tags(dec_tags, add_field=collaborator_name)
-            # layer.agg.n.trained.delta.lossy_decompressed.col_i
-            decompressed_tensor_key = TensorKey(
-                dec_name, dec_origin, dec_round_num, dec_report, new_tags
-            )
-
-        if "delta" in tags:
-            base_model_tensor_key = TensorKey(tensor_name, origin, round_number, report, ("model",))
-            base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tensor_key)
-            if base_model_nparray is None:
-                raise ValueError(f"Base model {base_model_tensor_key} not present in TensorDB")
-            final_tensor_key, final_nparray = self.tensor_codec.apply_delta(
-                decompressed_tensor_key,
-                decompressed_nparray,
-                base_model_nparray,
-            )
-        else:
-            final_tensor_key = decompressed_tensor_key
-            final_nparray = decompressed_nparray
+        updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name)
+        tensor_key = tensor_key._replace(tags=updated_tags)

-        assert final_nparray is not None, f"Could not create tensorkey {final_tensor_key}"
-        self.tensor_db.cache_tensor({final_tensor_key: final_nparray})
-        logger.debug("Created TensorKey: %s", final_tensor_key)
+        self.tensor_db.cache_tensor({tensor_key: nparray})

-        return final_tensor_key, final_nparray
+        return tensor_key, nparray

    def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
        """Prepare aggregated tensorkey tags.
@@ -982,82 +884,13 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
            report (bool): Whether to report.
            agg_results (np.array): Aggregated results.
        """
-        # The aggregated tensorkey tags should have the form of
-        # 'trained' or 'trained.lossy_decompressed'
-        # They need to be relabeled to 'aggregated' and
-        # reinserted. Then delta performed, compressed, etc.
-        # then reinserted to TensorDB with 'model' tag
-
-        # First insert the aggregated model layer with the
-        # correct tensorkey
        agg_tag_tk = TensorKey(tensor_name, origin, round_number + 1, report, ("aggregated",))
        self.tensor_db.cache_tensor({agg_tag_tk: agg_results})

-        # Create delta and save it in TensorDB
-        base_model_tk = TensorKey(tensor_name, origin, round_number, report, ("model",))
-        base_model_nparray = self.tensor_db.get_tensor_from_cache(base_model_tk)
-        if base_model_nparray is not None and self.use_delta_updates:
-            delta_tk, delta_nparray = self.tensor_codec.generate_delta(
-                agg_tag_tk, agg_results, base_model_nparray
-            )
-        else:
-            # This condition is possible for base model
-            # optimizer states (i.e. Adam/iter:0, SGD, etc.)
-            # These values couldn't be present for the base
-            # model because no training occurs on the aggregator
-            delta_tk, delta_nparray = agg_tag_tk, agg_results
-
-        # Compress lossless/lossy
-        compressed_delta_tk, compressed_delta_nparray, metadata = self.tensor_codec.compress(
-            delta_tk, delta_nparray
-        )
-
-        # TODO extend the TensorDB so that compressed data is
-        # supported. Once that is in place
-        # the compressed delta can just be stored here instead
-        # of recreating it for every request
-
-        # Decompress lossless/lossy
-        decompressed_delta_tk, decompressed_delta_nparray = self.tensor_codec.decompress(
-            compressed_delta_tk, compressed_delta_nparray, metadata
-        )
-
-        self.tensor_db.cache_tensor({decompressed_delta_tk: decompressed_delta_nparray})
-
-        # Apply delta (unless delta couldn't be created)
-        if base_model_nparray is not None and self.use_delta_updates:
-            logger.debug("Applying delta for layer %s", decompressed_delta_tk[0])
-            new_model_tk, new_model_nparray = self.tensor_codec.apply_delta(
-                decompressed_delta_tk,
-                decompressed_delta_nparray,
-                base_model_nparray,
-            )
-        else:
-            new_model_tk, new_model_nparray = (
-                decompressed_delta_tk,
-                decompressed_delta_nparray,
-            )
-
-        # Now that the model has been compressed/decompressed
-        # with delta operations,
-        # Relabel the tags to 'model'
-        (
-            new_model_tensor_name,
-            new_model_origin,
-            new_model_round_number,
-            new_model_report,
-            new_model_tags,
-        ) = new_model_tk
-        final_model_tk = TensorKey(
-            new_model_tensor_name,
-            new_model_origin,
-            new_model_round_number,
-            new_model_report,
-            ("model",),
-        )
-        self.next_model_round_number = new_model_round_number
-        # Finally, cache the updated model tensor
-        self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})
+        # Relabel the tags to 'model' and cache the updated model tensor
+        final_model_tk = agg_tag_tk._replace(tags=("model",))
+        self.next_model_round_number = final_model_tk.round_number
+        self.tensor_db.cache_tensor({final_model_tk: agg_results})

    def _compute_validation_related_task_metrics(self, task_name) -> dict:
        """Compute all validation related metrics.
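
Lastly, a minimal stand-in for the simplified `_prepare_trained` flow above: after this change the aggregated array is cached directly under an 'aggregated' key for the next round and again under a 'model' key, with no intermediate delta/compress/decompress round trip. The dict below is a hypothetical stand-in for the TensorDB cache, not the real class.

# Illustrative sketch only; a plain dict stands in for the TensorDB cache.
from collections import namedtuple

import numpy as np

TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])

tensor_db = {}  # hypothetical stand-in for self.tensor_db


def prepare_trained_sketch(tensor_name, origin, round_number, report, agg_results):
    # Cache the aggregated layer for the next round...
    agg_tag_tk = TensorKey(tensor_name, origin, round_number + 1, report, ("aggregated",))
    tensor_db[agg_tag_tk] = agg_results
    # ...and relabel the same array as the new 'model' tensor.
    final_model_tk = agg_tag_tk._replace(tags=("model",))
    tensor_db[final_model_tk] = agg_results
    return final_model_tk


key = prepare_trained_sketch("conv1.weight", "aggregator", 4, False, np.zeros(3))
print(key)             # round_number=5, tags=('model',)
print(len(tensor_db))  # 2 entries: one 'aggregated', one 'model'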