Skip to content

Commit cdbe2d6

Browse files
Revert "Disable streaming on send task results for consistency"
This reverts commit f6ac53e.

Signed-off-by: Shah, Karan <[email protected]>
1 parent 3c528ab commit cdbe2d6

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

openfl/protocols/aggregator.proto

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ service Aggregator {
1212
rpc Ping(PingRequest) returns (PingResponse) {}
1313
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse) {}
1414
rpc GetAggregatedTensors(GetAggregatedTensorsRequest) returns (GetAggregatedTensorsResponse) {}
15-
rpc SendLocalTaskResults(TaskResults) returns (SendLocalTaskResultsResponse) {}
15+
rpc SendLocalTaskResults(stream DataStream) returns (SendLocalTaskResultsResponse) {}
1616
rpc InteropRelay(InteropMessage) returns (InteropMessage) {}
1717
}
1818

openfl/transport/grpc/aggregator_client.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
import grpc
1212

13-
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc
13+
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
1414
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel
1515

1616
logger = logging.getLogger(__name__)
@@ -416,7 +416,8 @@ def send_local_task_results(
416416
tensors=named_tensors,
417417
)
418418

419-
response = self.stub.SendLocalTaskResults(request)
419+
# convert (potentially) long list of tensors into stream
420+
response = self.stub.SendLocalTaskResults(utils.proto_to_datastream(request))
420421
self.validate_response(response)
421422

422423
@_atomic_connection

openfl/transport/grpc/aggregator_server.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111
import grpc
1212

13-
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc
13+
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
1414
from openfl.transport.grpc.common import create_grpc_server, create_header, synchronized
1515

1616
logger = logging.getLogger(__name__)
@@ -256,7 +256,7 @@ def GetAggregatedTensors(self, request, context):
256256

257257
@synchronized
258258
def SendLocalTaskResults(self, request, context): # NOQA:N802
259-
"""Store collaborator task results on the aggregator.
259+
"""Request a model download from aggregator.
260260
261261
This method handles a request from a collaborator to send the results
262262
of a local task.
@@ -270,14 +270,23 @@ def SendLocalTaskResults(self, request, context): # NOQA:N802
270270
aggregator_pb2.SendLocalTaskResultsResponse: The response to the
271271
request.
272272
"""
273-
self.validate_collaborator(request, context)
274-
self.check_request(request)
273+
try:
274+
proto = aggregator_pb2.TaskResults()
275+
proto = utils.datastream_to_proto(proto, request)
276+
except RuntimeError:
277+
raise RuntimeError(
278+
"Empty stream message, reestablishing connection from client to resume training..."
279+
)
275280

276-
collaborator_name = request.header.sender
277-
task_name = request.task_name
278-
round_number = request.round_number
279-
data_size = request.data_size
280-
named_tensors = request.tensors
281+
self.validate_collaborator(proto, context)
282+
# all messages get sanity checked
283+
self.check_request(proto)
284+
285+
collaborator_name = proto.header.sender
286+
task_name = proto.task_name
287+
round_number = proto.round_number
288+
data_size = proto.data_size
289+
named_tensors = proto.tensors
281290
self.aggregator.send_local_task_results(
282291
collaborator_name, round_number, task_name, data_size, named_tensors
283292
)

0 commit comments

Comments
 (0)