10
10
11
11
import grpc
12
12
13
- from openfl .protocols import aggregator_pb2 , aggregator_pb2_grpc
13
+ from openfl .protocols import aggregator_pb2 , aggregator_pb2_grpc , utils
14
14
from openfl .transport .grpc .common import create_grpc_server , create_header , synchronized
15
15
16
16
logger = logging .getLogger (__name__ )
@@ -256,7 +256,7 @@ def GetAggregatedTensors(self, request, context):
256
256
257
257
@synchronized
258
258
def SendLocalTaskResults (self , request , context ): # NOQA:N802
259
- """Store collaborator task results on the aggregator.
259
+ """Request a model download from aggregator.
260
260
261
261
This method handles a request from a collaborator to send the results
262
262
of a local task.
@@ -270,14 +270,23 @@ def SendLocalTaskResults(self, request, context): # NOQA:N802
270
270
aggregator_pb2.SendLocalTaskResultsResponse: The response to the
271
271
request.
272
272
"""
273
- self .validate_collaborator (request , context )
274
- self .check_request (request )
273
+ try :
274
+ proto = aggregator_pb2 .TaskResults ()
275
+ proto = utils .datastream_to_proto (proto , request )
276
+ except RuntimeError :
277
+ raise RuntimeError (
278
+ "Empty stream message, reestablishing connection from client to resume training..."
279
+ )
275
280
276
- collaborator_name = request .header .sender
277
- task_name = request .task_name
278
- round_number = request .round_number
279
- data_size = request .data_size
280
- named_tensors = request .tensors
281
+ self .validate_collaborator (proto , context )
282
+ # all messages get sanity checked
283
+ self .check_request (proto )
284
+
285
+ collaborator_name = proto .header .sender
286
+ task_name = proto .task_name
287
+ round_number = proto .round_number
288
+ data_size = proto .data_size
289
+ named_tensors = proto .tensors
281
290
self .aggregator .send_local_task_results (
282
291
collaborator_name , round_number , task_name , data_size , named_tensors
283
292
)
0 commit comments