@@ -627,24 +627,6 @@ def get_aggregated_tensor(
             raise ValueError(f"Aggregator does not have `{tensor_key}`")
 
         # Serialize (and compress) the tensor
-        named_tensor = self.serialize_tensor(tensor_key, nparray, lossless=require_lossless)
-        return named_tensor
-
-    def serialize_tensor(self, tensor_key, nparray, lossless: bool):
-        """Serialize the tensor.
-
-        This function also performs compression.
-
-        Args:
-            tensor_key (namedtuple): A TensorKey.
-            nparray: A NumPy array associated with the requested
-                tensor key.
-            lossless: Whether to use lossless compression.
-
-        Returns:
-            named_tensor (protobuf): The tensor constructed from the nparray.
-        """
-        # Secure aggregation setup tensor.
         if "secagg" in tensor_key.tags:
             import numpy as np
 
@@ -663,14 +645,9 @@ def default(self, obj):
 
             return named_tensor
 
-        tensor_key, nparray, metadata = self.tensor_codec.compress(tensor_key, nparray, lossless)
-        named_tensor = utils.construct_named_tensor(
-            tensor_key,
-            nparray,
-            metadata,
-            lossless,
+        named_tensor = utils.serialize_tensor(
+            tensor_key, nparray, self.tensor_codec, lossless=require_lossless
         )
-
         return named_tensor
 
     def _collaborator_task_completed(self, collaborator, task_name, round_num):
@@ -769,43 +746,47 @@ def process_task_results(
         self._is_collaborator_done(collaborator_name, round_number)
         self._end_of_round_with_stragglers_check()
 
-        task_key = TaskResultKey(task_name, collaborator_name, round_number)
-
-        # we mustn't have results already
         if self._collaborator_task_completed(collaborator_name, task_name, round_number):
             logger.warning(
-                f"Aggregator already has task results from collaborator {collaborator_name}"
-                f" for task {task_key}"
+                f"Aggregator already has task results from collaborator {collaborator_name} "
+                f"for task {task_name} in round {round_number}. Ignoring..."
             )
             return
 
-        # By giving task_key it's own weight, we can support different
-        # training/validation weights
-        # As well as eventually supporting weights that change by round
-        # (if more data is added)
+        # Record collaborator individual weightage/contribution for federated averaging
+        task_key = TaskResultKey(task_name, collaborator_name, round_number)
         self.collaborator_task_weight[task_key] = data_size
 
-        # initialize the list of tensors that go with this task
-        # Setting these incrementally is leading to missing values
+        # Process named tensors
         task_results = []
-
+        result_tensor_dict = {}
         for named_tensor in named_tensors:
-            tensor_key, value = self.deserialize_tensor(named_tensor, collaborator_name)
+            # Deserialize
+            tensor_key, nparray = utils.deserialize_tensor(named_tensor, self.tensor_codec)
+
+            # Update origin/tags
+            updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name)
+            tensor_key = tensor_key._replace(origin=self.uuid, tags=updated_tags)
+
+            # Record
+            result_tensor_dict[tensor_key] = nparray
+            task_results.append(tensor_key)
 
             if "metric" in tensor_key.tags:
-                # Caution: This schema must be followed. It is also used in
-                # gRPC message streams for director/envoy.
+                assert nparray.ndim == 0, (
+                    f"Expected metric to be a scalar, got shape {nparray.shape}"
+                )
                 metrics = {
                     "round": round_number,
                     "metric_origin": collaborator_name,
                     "task_name": task_name,
                     "metric_name": tensor_key.tensor_name,
-                    "metric_value": float(value),
+                    "metric_value": float(nparray),
                 }
                 self.metric_queue.put(metrics)
 
-            task_results.append(tensor_key)
-
+        # Store results in TensorDB
+        self.tensor_db.cache_tensor(result_tensor_dict)
         self.collaborator_tasks_results[task_key] = task_results
 
         # Check if collaborator or round is done.
@@ -832,48 +813,6 @@ def _end_of_round_with_stragglers_check(self):
             logger.warning(f"Identified stragglers: {self.stragglers}")
         self._end_of_round_check()
 
-    def deserialize_tensor(self, named_tensor, collaborator_name):
-        """Deserialize a `NamedTensor` to a numpy array.
-
-        This function also performs decompresssion.
-
-        Args:
-            named_tensor (protobuf): The tensor to convert to nparray.
-
-        Returns:
-            A tuple (TensorKey, nparray).
-        """
-        metadata = [
-            {
-                "int_to_float": proto.int_to_float,
-                "int_list": proto.int_list,
-                "bool_list": proto.bool_list,
-            }
-            for proto in named_tensor.transformer_metadata
-        ]
-        # The tensor has already been transferred to aggregator,
-        # so the newly constructed tensor should have the aggregator origin
-        tensor_key = TensorKey(
-            named_tensor.name,
-            self.uuid,
-            named_tensor.round_number,
-            named_tensor.report,
-            tuple(named_tensor.tags),
-        )
-
-        tensor_key, nparray = self.tensor_codec.decompress(
-            tensor_key,
-            data=named_tensor.data_bytes,
-            transformer_metadata=metadata,
-            require_lossless=named_tensor.lossless,
-        )
-        updated_tags = change_tags(tensor_key.tags, add_field=collaborator_name)
-        tensor_key = tensor_key._replace(tags=updated_tags)
-
-        self.tensor_db.cache_tensor({tensor_key: nparray})
-
-        return tensor_key, nparray
-
     def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
         """Prepare aggregated tensorkey tags.
 
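Net effect of the change: the compress-and-pack / unpack-and-decompress steps move out of the Aggregator into shared utils.serialize_tensor / utils.deserialize_tensor helpers, while the aggregator-specific work (rewriting origin and tags, batching results into the TensorDB) stays at the call site. Below is a minimal sketch of the round trip those helpers are expected to perform, based only on the call signatures visible in this diff; the IdentityCodec and the plain dict standing in for the NamedTensor protobuf are illustrative assumptions, not the project's actual codec or message type.

from collections import namedtuple

import numpy as np

TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])


class IdentityCodec:
    """Stand-in codec: packs raw bytes plus enough metadata to rebuild the array."""

    def compress(self, tensor_key, nparray, lossless):
        metadata = [{"dtype": str(nparray.dtype), "shape": nparray.shape}]
        return tensor_key, nparray.tobytes(), metadata

    def decompress(self, tensor_key, data, transformer_metadata, require_lossless):
        meta = transformer_metadata[0]
        nparray = np.frombuffer(data, dtype=meta["dtype"]).reshape(meta["shape"])
        return tensor_key, nparray


def serialize_tensor(tensor_key, nparray, tensor_codec, lossless):
    """Compress, then pack into a NamedTensor-like container (here: a dict)."""
    tensor_key, data, metadata = tensor_codec.compress(tensor_key, nparray, lossless)
    return {
        "name": tensor_key.tensor_name,
        "round_number": tensor_key.round_number,
        "report": tensor_key.report,
        "tags": tuple(tensor_key.tags),
        "data_bytes": data,
        "transformer_metadata": metadata,
        "lossless": lossless,
    }


def deserialize_tensor(named_tensor, tensor_codec):
    """Unpack, then decompress; origin is left for the caller to fill in."""
    tensor_key = TensorKey(
        named_tensor["name"],
        None,  # the aggregator rewrites origin via tensor_key._replace(...)
        named_tensor["round_number"],
        named_tensor["report"],
        named_tensor["tags"],
    )
    return tensor_codec.decompress(
        tensor_key,
        data=named_tensor["data_bytes"],
        transformer_metadata=named_tensor["transformer_metadata"],
        require_lossless=named_tensor["lossless"],
    )


# Round trip, mirroring the call sites in the diff.
codec = IdentityCodec()
key = TensorKey("conv1/weight", "collaborator-1", 0, False, ("trained",))
array = np.arange(6, dtype=np.float32).reshape(2, 3)
named = serialize_tensor(key, array, codec, lossless=True)
restored_key, restored = deserialize_tensor(named, codec)
assert restored_key.tensor_name == key.tensor_name
assert np.array_equal(restored, array)

Note the design choice visible in process_task_results: deserialized arrays are collected into result_tensor_dict and written to the tensor DB in a single cache_tensor call, replacing the per-tensor caching that the removed deserialize_tensor method performed.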