@@ -741,61 +741,97 @@ def post_task_results():
741
741
def _setup_tensor_route(self):
    """Set up the /tensors/aggregated/batch endpoint.

    Registers a POST route on ``self.app`` that accepts a JSON-encoded
    ``GetAggregatedTensorsRequest`` and returns a JSON-encoded
    ``GetAggregatedTensorsResponse`` holding one tensor per requested spec,
    in request order.
    """

    @self.app.route(f"/{self.api_prefix}/tensors/aggregated/batch", methods=["POST"])
    def get_aggregated_tensors():
        """Endpoint for collaborators to retrieve multiple aggregated tensors.

        Aborts with:
            501: the server runs in connector mode.
            400: the payload is not a JSON object, cannot be parsed into the
                request message, fails header validation, or tensor
                retrieval raises.
        Any HTTP error raised by ``self._is_authorized`` propagates unchanged.
        """
        start_time = time.time()

        # Validate that this endpoint is not used in connector mode
        if self.use_connector:
            abort(501, "GetAggregatedTensors not supported in connector mode")

        # silent=True yields None for a malformed body or wrong content type,
        # so every unusable payload is reported through the same 400 below.
        request_data = request.get_json(silent=True)
        if not request_data or not isinstance(request_data, dict):
            abort(400, "Invalid JSON payload")

        # Parse the incoming JSON to a GetAggregatedTensorsRequest protobuf message.
        # Only the parse call is guarded so that abort() is never swallowed by
        # a broad exception handler.
        tensors_request = aggregator_pb2.GetAggregatedTensorsRequest()
        try:
            json_format.ParseDict(request_data, tensors_request, ignore_unknown_fields=True)
        except json_format.ParseError as e:
            logger.error(f"Error processing batch tensor request: {e}")
            abort(400, f"Error processing batch tensor request: {e}")

        # Validate headers and get collaborator identity
        request_header = tensors_request.header
        collaborator_id = request_header.sender
        federation_id = request_header.federation_uuid

        # Use the consolidated validation method. It runs outside any
        # try/except so its own HTTP status reaches the client untouched.
        self._is_authorized(collaborator_id, federation_id)

        # Validate request header similar to gRPC implementation. Explicit
        # comparisons are used instead of `assert`, which is removed when
        # Python runs with -O and would silently disable these checks.
        expected_cn = self.aggregator.single_col_cert_common_name or ""
        header_checks = (
            (
                "Header receiver mismatch",
                str(self.aggregator.uuid),
                request_header.receiver,
            ),
            (
                "Federation UUID mismatch",
                str(self.aggregator.federation_uuid),
                request_header.federation_uuid,
            ),
            (
                "Single col cert CN mismatch",
                expected_cn,
                request_header.single_col_cert_common_name,
            ),
        )
        for label, expected, actual in header_checks:
            if actual != expected:
                message = f"{label}. Expected: {expected}, Got: {actual}"
                logger.error(f"Header validation failed: {message}")
                abort(400, message)

        # Get tensors from aggregator - similar to gRPC implementation
        logger.debug(
            f"Processing batch request for {len(tensors_request.tensor_specs)} tensors"
        )

        named_tensors = []
        try:
            for ts in tensors_request.tensor_specs:
                named_tensor = self.aggregator.get_aggregated_tensor(
                    ts.tensor_name,
                    ts.round_number,
                    ts.report,
                    tuple(ts.tags),
                    ts.require_lossless,
                    collaborator_id,
                )
                # An empty placeholder keeps response order aligned with the
                # request; missing tensors are handled by the client.
                named_tensors.append(
                    named_tensor
                    if named_tensor is not None
                    else aggregator_pb2.NamedTensorProto()
                )
        except Exception as e:
            # NOTE(review): status 400 is kept for backward compatibility with
            # existing clients; a 5xx may be more accurate for failures that
            # originate inside the aggregator.
            logger.exception(f"Error processing batch tensor request: {e}")
            abort(400, f"Error processing batch tensor request: {e}")

        # Create response header using the standardized method
        response_header = create_header(
            sender=str(self.aggregator.uuid),
            receiver=collaborator_id,
            federation_uuid=str(self.aggregator.federation_uuid),
            single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "",
        )

        # Create response
        response_proto = aggregator_pb2.GetAggregatedTensorsResponse(
            header=response_header, tensors=named_tensors
        )

        logger.debug(
            f"Batch tensor retrieval completed in {time.time() - start_time:.2f} seconds. "
            f"Returned {len(named_tensors)} tensors"
        )
        return jsonify(json_format.MessageToDict(response_proto))
799
835
800
836
def _setup_relay_route (self ):
801
837
"""Set up the /interop/relay endpoint."""
0 commit comments