@@ -283,8 +283,8 @@ def fetch_tensors_from_aggregator(self, tensor_keys: List[TensorKey]):
283
283
if len (tensor_keys ) > 0 :
284
284
logger .info ("Fetching %d tensors from the aggregator" , len (tensor_keys ))
285
285
named_tensors = self .client .get_aggregated_tensors (tensor_keys , require_lossless = True )
286
- arrays = [ self . named_tensor_to_nparray ( named_tensor ) for named_tensor in named_tensors ]
287
- tensor_dict = dict ( zip ( tensor_keys , arrays ) )
286
+ for tensor_key , named_tensor in zip ( tensor_keys , named_tensors ):
287
+ tensor_dict [ tensor_key ] = self . deserialize_tensor ( named_tensor )
288
288
289
289
self .tensor_db .cache_tensor (tensor_dict )
290
290
@@ -299,8 +299,6 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
299
299
Returns:
300
300
A dictionary of reportable metrics of the current collaborator for the task.
301
301
"""
302
- named_tensors = [self .nparray_to_named_tensor (k , v ) for k , v in tensor_dict .items ()]
303
-
304
302
# for general tasks, there may be no notion of data size to send.
305
303
# But that raises the question how to properly aggregate results.
306
304
@@ -323,6 +321,9 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
323
321
value = float (tensor_dict [tensor ])
324
322
metrics .update ({f"{ self .collaborator_name } /{ task_name } /{ tensor_name } " : value })
325
323
324
+ # Serialize tensors to be sent to the aggregator
325
+ named_tensors = [self .serialize_tensor (k , v ) for k , v in tensor_dict .items ()]
326
+
326
327
self .client .send_local_task_results (
327
328
round_number ,
328
329
task_name ,
@@ -332,71 +333,44 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
332
333
333
334
return metrics
334
335
335
- def nparray_to_named_tensor (self , tensor_key , nparray ):
336
- """Construct the NamedTensor Protobuf .
336
+ def serialize_tensor (self , tensor_key , nparray ):
337
+ """Serialize the tensor .
337
338
338
- Includes logic to create delta, compress tensors with the TensorCodec,
339
- etc.
339
+ This function also performs compression.
340
340
341
341
Args:
342
- tensor_key (namedtuple): Tensorkey that will be resolved locally or
343
- remotely. May be the product of other tensors.
344
- nparray: The decompressed tensor associated with the requested
342
+ tensor_key (namedtuple): A TensorKey.
343
+ nparray: A NumPy array associated with the requested
345
344
tensor key.
346
345
347
346
Returns:
348
347
named_tensor (protobuf) : The tensor constructed from the nparray.
349
348
"""
350
- # if we have an aggregated tensor, we can make a delta
351
- tensor_name , origin , round_number , report , tags = tensor_key
352
- if "trained" in tags and self .use_delta_updates :
353
- # Should get the pretrained model to create the delta. If training
354
- # has happened,
355
- # Model should already be stored in the TensorDB
356
- model_nparray = self .tensor_db .get_tensor_from_cache (
357
- TensorKey (tensor_name , origin , round_number , report , ("model" ,))
358
- )
359
-
360
- # The original model will not be present for the optimizer on the
361
- # first round.
362
- if model_nparray is not None :
363
- delta_tensor_key , delta_nparray = self .tensor_codec .generate_delta (
364
- tensor_key , nparray , model_nparray
365
- )
366
- delta_comp_tensor_key , delta_comp_nparray , metadata = self .tensor_codec .compress (
367
- delta_tensor_key , delta_nparray
368
- )
369
-
370
- named_tensor = utils .construct_named_tensor (
371
- delta_comp_tensor_key ,
372
- delta_comp_nparray ,
373
- metadata ,
374
- lossless = False ,
375
- )
376
- return named_tensor
377
-
378
- # Assume every other tensor requires lossless compression
379
- compressed_tensor_key , compressed_nparray , metadata = self .tensor_codec .compress (
380
- tensor_key , nparray , require_lossless = True
349
+ lossless = True
350
+ tensor_key , nparray , metadata = self .tensor_codec .compress (
351
+ tensor_key ,
352
+ nparray ,
353
+ lossless ,
381
354
)
382
355
named_tensor = utils .construct_named_tensor (
383
- compressed_tensor_key , compressed_nparray , metadata , lossless = True
356
+ tensor_key ,
357
+ nparray ,
358
+ metadata ,
359
+ lossless ,
384
360
)
385
-
386
361
return named_tensor
387
362
388
- def named_tensor_to_nparray (self , named_tensor ):
389
- """Convert named tensor to a numpy array.
363
+ def deserialize_tensor (self , named_tensor ):
364
+ """Deserialize a `NamedTensor` to a numpy array.
365
+
366
+ This function also performs decompression.
390
367
391
368
Args:
392
369
named_tensor (protobuf): The tensor to convert to nparray.
393
370
394
371
Returns:
395
- decompressed_nparray (nparray): The nparray converted.
372
+ The converted nparray .
396
373
"""
397
- # do the stuff we do now for decompression and frombuffer and stuff
398
- # This should probably be moved back to protoutils
399
- raw_bytes = named_tensor .data_bytes
400
374
metadata = [
401
375
{
402
376
"int_to_float" : proto .int_to_float ,
@@ -414,28 +388,15 @@ def named_tensor_to_nparray(self, named_tensor):
414
388
named_tensor .report ,
415
389
tuple (named_tensor .tags ),
416
390
)
417
- * _ , tags = tensor_key
418
- if "compressed" in tags :
419
- decompressed_tensor_key , decompressed_nparray = self .tensor_codec .decompress (
420
- tensor_key ,
421
- data = raw_bytes ,
422
- transformer_metadata = metadata ,
423
- require_lossless = True ,
424
- )
425
- elif "lossy_compressed" in tags :
426
- decompressed_tensor_key , decompressed_nparray = self .tensor_codec .decompress (
427
- tensor_key , data = raw_bytes , transformer_metadata = metadata
428
- )
429
- else :
430
- # There could be a case where the compression pipeline is bypassed
431
- # entirely
432
- logger .warning ("Bypassing tensor codec..." )
433
- decompressed_tensor_key = tensor_key
434
- decompressed_nparray = raw_bytes
435
391
436
- self .tensor_db .cache_tensor ({decompressed_tensor_key : decompressed_nparray })
392
+ tensor_key , nparray = self .tensor_codec .decompress (
393
+ tensor_key ,
394
+ data = named_tensor .data_bytes ,
395
+ transformer_metadata = metadata ,
396
+ require_lossless = named_tensor .lossless ,
397
+ )
437
398
438
- return decompressed_nparray
399
+ return nparray
439
400
440
401
def _apply_masks (
441
402
self ,
0 commit comments