
Commit c915b16

Simplify collaborator side send/receive with no delta calculations
Signed-off-by: Shah, Karan <[email protected]>
1 parent 5c149d2 commit c915b16

File tree

1 file changed: +31 −70 lines changed


openfl/component/collaborator/collaborator.py

Lines changed: 31 additions & 70 deletions
@@ -283,8 +283,8 @@ def fetch_tensors_from_aggregator(self, tensor_keys: List[TensorKey]):
         if len(tensor_keys) > 0:
             logger.info("Fetching %d tensors from the aggregator", len(tensor_keys))
             named_tensors = self.client.get_aggregated_tensors(tensor_keys, require_lossless=True)
-            arrays = [self.named_tensor_to_nparray(named_tensor) for named_tensor in named_tensors]
-            tensor_dict = dict(zip(tensor_keys, arrays))
+            for tensor_key, named_tensor in zip(tensor_keys, named_tensors):
+                tensor_dict[tensor_key] = self.deserialize_tensor(named_tensor)
 
         self.tensor_db.cache_tensor(tensor_dict)
 
@@ -299,8 +299,6 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
         Returns:
             A dictionary of reportable metrics of the current collaborator for the task.
         """
-        named_tensors = [self.nparray_to_named_tensor(k, v) for k, v in tensor_dict.items()]
-
         # for general tasks, there may be no notion of data size to send.
         # But that raises the question how to properly aggregate results.
 
@@ -323,6 +321,9 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
                 value = float(tensor_dict[tensor])
                 metrics.update({f"{self.collaborator_name}/{task_name}/{tensor_name}": value})
 
+        # Serialize tensors to be sent to the aggregator
+        named_tensors = [self.serialize_tensor(k, v) for k, v in tensor_dict.items()]
+
         self.client.send_local_task_results(
             round_number,
             task_name,
@@ -332,71 +333,44 @@ def send_task_results(self, tensor_dict, round_number, task_name) -> dict:
 
         return metrics
 
-    def nparray_to_named_tensor(self, tensor_key, nparray):
-        """Construct the NamedTensor Protobuf.
+    def serialize_tensor(self, tensor_key, nparray):
+        """Serialize the tensor.
 
-        Includes logic to create delta, compress tensors with the TensorCodec,
-        etc.
+        This function also performs compression.
 
         Args:
-            tensor_key (namedtuple): Tensorkey that will be resolved locally or
-                remotely. May be the product of other tensors.
-            nparray: The decompressed tensor associated with the requested
+            tensor_key (namedtuple): A TensorKey.
+            nparray: A NumPy array associated with the requested
                 tensor key.
 
         Returns:
             named_tensor (protobuf) : The tensor constructed from the nparray.
         """
-        # if we have an aggregated tensor, we can make a delta
-        tensor_name, origin, round_number, report, tags = tensor_key
-        if "trained" in tags and self.use_delta_updates:
-            # Should get the pretrained model to create the delta. If training
-            # has happened,
-            # Model should already be stored in the TensorDB
-            model_nparray = self.tensor_db.get_tensor_from_cache(
-                TensorKey(tensor_name, origin, round_number, report, ("model",))
-            )
-
-            # The original model will not be present for the optimizer on the
-            # first round.
-            if model_nparray is not None:
-                delta_tensor_key, delta_nparray = self.tensor_codec.generate_delta(
-                    tensor_key, nparray, model_nparray
-                )
-                delta_comp_tensor_key, delta_comp_nparray, metadata = self.tensor_codec.compress(
-                    delta_tensor_key, delta_nparray
-                )
-
-                named_tensor = utils.construct_named_tensor(
-                    delta_comp_tensor_key,
-                    delta_comp_nparray,
-                    metadata,
-                    lossless=False,
-                )
-                return named_tensor
-
-        # Assume every other tensor requires lossless compression
-        compressed_tensor_key, compressed_nparray, metadata = self.tensor_codec.compress(
-            tensor_key, nparray, require_lossless=True
+        lossless = True
+        tensor_key, nparray, metadata = self.tensor_codec.compress(
+            tensor_key,
+            nparray,
+            lossless,
         )
         named_tensor = utils.construct_named_tensor(
-            compressed_tensor_key, compressed_nparray, metadata, lossless=True
+            tensor_key,
+            nparray,
+            metadata,
+            lossless,
         )
-
         return named_tensor
 
-    def named_tensor_to_nparray(self, named_tensor):
-        """Convert named tensor to a numpy array.
+    def deserialize_tensor(self, named_tensor):
+        """Deserialize a `NamedTensor` to a numpy array.
+
+        This function also performs decompression.
 
         Args:
             named_tensor (protobuf): The tensor to convert to nparray.
 
         Returns:
-            decompressed_nparray (nparray): The nparray converted.
+            The converted nparray.
         """
-        # do the stuff we do now for decompression and frombuffer and stuff
-        # This should probably be moved back to protoutils
-        raw_bytes = named_tensor.data_bytes
         metadata = [
             {
                 "int_to_float": proto.int_to_float,
@@ -414,28 +388,15 @@ def named_tensor_to_nparray(self, named_tensor):
             named_tensor.report,
             tuple(named_tensor.tags),
         )
-        *_, tags = tensor_key
-        if "compressed" in tags:
-            decompressed_tensor_key, decompressed_nparray = self.tensor_codec.decompress(
-                tensor_key,
-                data=raw_bytes,
-                transformer_metadata=metadata,
-                require_lossless=True,
-            )
-        elif "lossy_compressed" in tags:
-            decompressed_tensor_key, decompressed_nparray = self.tensor_codec.decompress(
-                tensor_key, data=raw_bytes, transformer_metadata=metadata
-            )
-        else:
-            # There could be a case where the compression pipeline is bypassed
-            # entirely
-            logger.warning("Bypassing tensor codec...")
-            decompressed_tensor_key = tensor_key
-            decompressed_nparray = raw_bytes
 
-        self.tensor_db.cache_tensor({decompressed_tensor_key: decompressed_nparray})
+        tensor_key, nparray = self.tensor_codec.decompress(
+            tensor_key,
+            data=named_tensor.data_bytes,
+            transformer_metadata=metadata,
+            require_lossless=named_tensor.lossless,
+        )
 
-        return decompressed_nparray
+        return nparray
 
     def _apply_masks(
         self,
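
For context, the round trip that remains after this change is a plain compress-on-send / decompress-on-receive pair with no delta branch and no tag inspection. The sketch below is illustrative only: it does not use the real OpenFL TensorCodec, NamedTensor protobuf, or Collaborator class. FakeNamedTensor and the raw byte-dump "codec" are hypothetical stand-ins that mirror the shape of the serialize_tensor / deserialize_tensor pair introduced by this commit.

# Illustrative sketch only: hypothetical stand-ins for the NamedTensor
# protobuf and the TensorCodec used by the real collaborator.
from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass
class FakeNamedTensor:
    """Hypothetical stand-in for the NamedTensor protobuf sent over gRPC."""
    name: str
    data_bytes: bytes
    dtype: str
    shape: Tuple[int, ...]
    lossless: bool = True


def serialize_tensor(name: str, nparray: np.ndarray) -> FakeNamedTensor:
    # Mirrors the simplified send path: always lossless, no delta against
    # a previously cached model. "Compression" here is a raw byte dump.
    return FakeNamedTensor(
        name=name,
        data_bytes=nparray.tobytes(),
        dtype=str(nparray.dtype),
        shape=tuple(nparray.shape),
        lossless=True,
    )


def deserialize_tensor(named_tensor: FakeNamedTensor) -> np.ndarray:
    # Mirrors the simplified receive path: decode honoring the lossless
    # flag recorded by the sender, with no tag-based branching.
    assert named_tensor.lossless
    flat = np.frombuffer(named_tensor.data_bytes, dtype=named_tensor.dtype)
    return flat.reshape(named_tensor.shape)


if __name__ == "__main__":
    weights = np.random.rand(2, 3).astype(np.float32)
    received = deserialize_tensor(serialize_tensor("layer0.weight", weights))
    assert np.array_equal(weights, received)

In the actual diff above, the lossless flag travels on the NamedTensor itself (require_lossless=named_tensor.lossless), which is what allows deserialize_tensor to drop the old "compressed" / "lossy_compressed" tag checks.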

0 commit comments
