Skip to content

Commit c13a40b

Browse files
committed
Patch REST get_aggregated_tensor api for batched changes as per securefederatedai#1575
- rename the function to get_aggregated_tensors - adjust both the gRPC and REST clients for the name and signature change in the API - for the protobuf changes of TensorSpec and the batched response, adjust both client and server in line with gRPC - fix the test cases rebased 29th.May.1 Signed-off-by: Shailesh Pant <[email protected]>
1 parent a551fb6 commit c13a40b

File tree

4 files changed

+144
-84
lines changed

4 files changed

+144
-84
lines changed

openfl/protocols/aggregator_client_interface.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,10 @@ def get_tasks(self) -> Tuple[List[Any], int, int, bool]:
2424
pass
2525

2626
@abstractmethod
27-
def get_aggregated_tensor(
27+
def get_aggregated_tensors(
2828
self,
29-
tensor_name: str,
30-
round_number: int,
31-
report: bool,
32-
tags: List[str],
33-
require_lossless: bool,
29+
tensor_keys,
30+
require_lossless: bool = True,
3431
) -> Any:
3532
"""
3633
Retrieves the aggregated tensor.

openfl/transport/grpc/aggregator_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010

1111
import grpc
1212

13-
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
14-
from openfl.protocols.aggregator_client_interface import AggregatorClientInterface
15-
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc
1613
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, base_pb2, utils
14+
from openfl.protocols.aggregator_client_interface import AggregatorClientInterface
1715
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel
1816

1917
logger = logging.getLogger(__name__)

openfl/transport/rest/aggregator_client.py

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -443,46 +443,75 @@ def get_tasks(self) -> Tuple[List[Any], int, int, bool]:
443443
)
444444
return tasks_resp.tasks, tasks_resp.round_number, tasks_resp.sleep_time, tasks_resp.quit
445445

446-
def get_aggregated_tensor(
446+
def get_aggregated_tensors(
447447
self,
448-
tensor_name: str,
449-
round_number: int,
450-
report: bool,
451-
tags: List[str],
452-
require_lossless: bool,
453-
) -> Any:
454-
"""Get aggregated tensor with proper security settings."""
455-
params = {
456-
"sender": self.collaborator_name,
457-
"receiver": self.aggregator_uuid,
458-
"federation_uuid": self.federation_uuid,
459-
"tensor_name": tensor_name,
460-
"round_number": round_number,
461-
"report": report,
462-
"tags": tags,
463-
"require_lossless": require_lossless,
464-
"collaborator_id": self.collaborator_name,
448+
tensor_keys,
449+
require_lossless: bool = True,
450+
) -> List[base_pb2.NamedTensor]:
451+
"""
452+
Get aggregated tensors from the aggregator.
453+
454+
Args:
455+
tensor_keys (list): A list of tensor keys to fetch from aggregator.
456+
require_lossless (bool): Whether lossless compression is required.
457+
458+
Returns:
459+
A list of `NamedTensor`s in the same order as requested.
460+
"""
461+
logger.debug(f"Requesting {len(tensor_keys)} aggregated tensors")
462+
463+
# Build the request payload similar to gRPC implementation
464+
tensor_specs = []
465+
for k in tensor_keys:
466+
tensor_specs.append(
467+
{
468+
"tensor_name": k.tensor_name,
469+
"round_number": k.round_number,
470+
"report": k.report,
471+
"tags": k.tags,
472+
"require_lossless": require_lossless,
473+
}
474+
)
475+
476+
request_data = {
477+
"header": {
478+
"sender": self.collaborator_name,
479+
"receiver": self.aggregator_uuid,
480+
"federation_uuid": self.federation_uuid,
481+
"single_col_cert_common_name": self.single_col_cert_common_name or "",
482+
},
483+
"tensor_specs": tensor_specs,
465484
}
466-
headers = {"Accept": "application/json", "Sender": self.collaborator_name}
467-
url = f"{self.base_url}/tensors/aggregated"
485+
486+
headers = {
487+
"Accept": "application/json",
488+
"Content-Type": "application/json",
489+
"Sender": self.collaborator_name,
490+
}
491+
url = f"{self.base_url}/tensors/aggregated/batch"
468492
extended_timeout = (30, 600) # 30 seconds connect, 10 minutes read timeout
493+
469494
try:
470-
logger.debug(f"Requesting aggregated tensor {tensor_name} for round {round_number}")
495+
logger.debug(f"Requesting batch of {len(tensor_keys)} aggregated tensors")
471496
response = self._make_request(
472-
"GET", url, params=params, headers=headers, timeout=extended_timeout
497+
"POST",
498+
url,
499+
data=json_format.MessageToJson(
500+
aggregator_pb2.GetAggregatedTensorsRequest(**request_data)
501+
),
502+
headers=headers,
503+
timeout=extended_timeout,
473504
)
474505
data = response.json()
475-
resp = aggregator_pb2.GetAggregatedTensorResponse()
506+
resp = aggregator_pb2.GetAggregatedTensorsResponse()
476507
json_format.ParseDict(data, resp, ignore_unknown_fields=True)
477-
logger.debug(f"Successfully retrieved tensor {tensor_name} for round {round_number}")
478-
return resp.tensor
508+
logger.debug(f"Successfully retrieved {len(resp.tensors)} aggregated tensors")
509+
return resp.tensors
479510
except requests.exceptions.HTTPError as e:
480511
if e.response.status_code == 404:
481-
# This is expected during round 0 or when tensor hasn't been aggregated yet
482-
logger.debug(
483-
f"No aggregated tensor found for {tensor_name} at round {round_number}"
484-
)
485-
return None
512+
# This is expected during round 0 or when tensors haven't been aggregated yet
513+
logger.debug("No aggregated tensors found for the requested tensor keys")
514+
return []
486515
raise
487516

488517
def send_local_task_results(

openfl/transport/rest/aggregator_server.py

Lines changed: 81 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -741,61 +741,97 @@ def post_task_results():
741741
def _setup_tensor_route(self):
742742
"""Set up the /tensors/aggregated endpoint."""
743743

744-
@self.app.route(f"/{self.api_prefix}/tensors/aggregated", methods=["GET"])
745-
def get_aggregated_tensor():
746-
"""Endpoint for collaborators to retrieve an aggregated tensor."""
744+
@self.app.route(f"/{self.api_prefix}/tensors/aggregated/batch", methods=["POST"])
745+
def get_aggregated_tensors():
746+
"""Endpoint for collaborators to retrieve multiple aggregated tensors."""
747747
start_time = time.time()
748748

749749
# Validate that this endpoint is not used in connector mode
750750
if self.use_connector:
751-
abort(501, "GetAggregatedTensor not supported in connector mode")
751+
abort(501, "GetAggregatedTensors not supported in connector mode")
752752

753-
# Get and validate collaborator identity
754-
collaborator_id = request.args.get("collaborator_id")
755-
federation_id = request.args.get("federation_uuid")
753+
try:
754+
# Parse the incoming JSON to a GetAggregatedTensorsRequest protobuf message
755+
request_data = request.get_json()
756+
if not request_data:
757+
abort(400, "Invalid JSON payload")
756758

757-
# Use the consolidated validation method
758-
self._is_authorized(collaborator_id, federation_id)
759+
tensors_request = aggregator_pb2.GetAggregatedTensorsRequest()
760+
json_format.ParseDict(request_data, tensors_request, ignore_unknown_fields=True)
759761

760-
# Extract tensor request parameters
761-
tensor_name = request.args.get("tensor_name")
762-
try:
763-
round_number = int(request.args.get("round_number", 0))
764-
except (TypeError, ValueError):
765-
abort(400, "Invalid round number")
766-
report = request.args.get("report", "").lower() == "true"
767-
tags = request.args.getlist("tags")
768-
require_lossless = request.args.get("require_lossless", "").lower() == "true"
769-
770-
# Get the tensor from aggregator - direct delegation to the aggregator
771-
named_tensor = self.aggregator.get_aggregated_tensor(
772-
tensor_name,
773-
round_number,
774-
report=report,
775-
tags=tuple(tags),
776-
require_lossless=require_lossless,
777-
requested_by=collaborator_id,
778-
)
762+
# Validate headers and get collaborator identity
763+
collaborator_id = tensors_request.header.sender
764+
federation_id = tensors_request.header.federation_uuid
779765

780-
# Create response header using the standardized method
781-
header = create_header(
782-
sender=str(self.aggregator.uuid),
783-
receiver=collaborator_id,
784-
federation_uuid=str(self.aggregator.federation_uuid),
785-
single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "",
786-
)
766+
# Use the consolidated validation method
767+
self._is_authorized(collaborator_id, federation_id)
787768

788-
# Create response with empty tensor if not found
789-
response_proto = aggregator_pb2.GetAggregatedTensorResponse(
790-
header=header,
791-
round_number=round_number,
792-
tensor=named_tensor
793-
if named_tensor is not None
794-
else aggregator_pb2.NamedTensorProto(),
795-
)
769+
# Validate request header similar to gRPC implementation
770+
assert tensors_request.header.receiver == str(self.aggregator.uuid), (
771+
f"Header receiver mismatch. Expected: {self.aggregator.uuid}, "
772+
f"Got: {tensors_request.header.receiver}"
773+
)
796774

797-
logger.debug(f"Tensor retrieval completed in {time.time() - start_time:.2f} seconds")
798-
return jsonify(json_format.MessageToDict(response_proto))
775+
assert tensors_request.header.federation_uuid == str(
776+
self.aggregator.federation_uuid
777+
), (
778+
f"Federation UUID mismatch. Expected: {self.aggregator.federation_uuid}, "
779+
f"Got: {tensors_request.header.federation_uuid}"
780+
)
781+
782+
expected_cn = self.aggregator.single_col_cert_common_name or ""
783+
assert tensors_request.header.single_col_cert_common_name == expected_cn, (
784+
f"Single col cert CN mismatch. Expected: {expected_cn}, "
785+
f"Got: {tensors_request.header.single_col_cert_common_name}"
786+
)
787+
788+
# Get tensors from aggregator - similar to gRPC implementation
789+
logger.debug(
790+
f"Processing batch request for {len(tensors_request.tensor_specs)} tensors"
791+
)
792+
793+
named_tensors = []
794+
for ts in tensors_request.tensor_specs:
795+
named_tensor = self.aggregator.get_aggregated_tensor(
796+
ts.tensor_name,
797+
ts.round_number,
798+
ts.report,
799+
tuple(ts.tags),
800+
ts.require_lossless,
801+
collaborator_id,
802+
)
803+
# Add tensor to list (None tensors will be handled by the client)
804+
if named_tensor is not None:
805+
named_tensors.append(named_tensor)
806+
else:
807+
# Add empty tensor placeholder to maintain order
808+
named_tensors.append(aggregator_pb2.NamedTensorProto())
809+
810+
# Create response header using the standardized method
811+
header = create_header(
812+
sender=str(self.aggregator.uuid),
813+
receiver=collaborator_id,
814+
federation_uuid=str(self.aggregator.federation_uuid),
815+
single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "",
816+
)
817+
818+
# Create response
819+
response_proto = aggregator_pb2.GetAggregatedTensorsResponse(
820+
header=header, tensors=named_tensors
821+
)
822+
823+
logger.debug(
824+
f"Batch tensor retrieval completed in {time.time() - start_time:.2f} seconds. "
825+
f"Returned {len(named_tensors)} tensors"
826+
)
827+
return jsonify(json_format.MessageToDict(response_proto))
828+
829+
except AssertionError as e:
830+
logger.error(f"Header validation failed: {str(e)}")
831+
abort(400, str(e))
832+
except Exception as e:
833+
logger.error(f"Error processing batch tensor request: {str(e)}")
834+
abort(400, f"Error processing batch tensor request: {str(e)}")
799835

800836
def _setup_relay_route(self):
801837
"""Set up the /interop/relay endpoint."""

0 commit comments

Comments
 (0)