Skip to content

Commit 439a6a7

Browse files
Generic ndarray transformer
Signed-off-by: Shah, Karan <[email protected]>
1 parent f9fe941 commit 439a6a7

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

openfl/pipelines/no_compression_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
"""NoCompressionPipeline module."""
66

7-
from openfl.pipelines.pipeline import Float32NumpyArrayToBytes, TransformationPipeline
7+
from openfl.pipelines.pipeline import NumpyArrayToBytes, TransformationPipeline
88

99

1010
class NoCompressionPipeline(TransformationPipeline):
1111
"""The data pipeline without any compression."""
1212

1313
def __init__(self, **kwargs):
1414
"""Initialize."""
15-
super().__init__(transformers=[Float32NumpyArrayToBytes()], **kwargs)
15+
super().__init__(transformers=[NumpyArrayToBytes()], **kwargs)

openfl/pipelines/pipeline.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,44 @@ def backward(self, data, metadata, **kwargs):
9393
return np.reshape(flat_array, newshape=array_shape, order="C")
9494

9595

96+
class NumpyArrayToBytes(Transformer):
97+
"""Transformer for converting generic Numpy arrays to bytes."""
98+
99+
def __init__(self):
100+
self.lossy = False
101+
102+
def forward(self, data: np.ndarray, **kwargs):
103+
"""Convert a Numpy array to bytes.
104+
105+
Args:
106+
data: The Numpy array to be converted.
107+
**kwargs: Additional keyword arguments for the conversion.
108+
109+
Returns:
110+
data_bytes: The data converted to bytes.
111+
metadata: The metadata for the conversion.
112+
"""
113+
array_shape = data.shape
114+
metadata = {"int_list": list(array_shape), "dtype": str(data.dtype)}
115+
data_bytes = data.tobytes(order="C")
116+
return data_bytes, metadata
117+
118+
def backward(self, data, metadata, **kwargs):
119+
"""Convert bytes back to a Numpy array.
120+
121+
Args:
122+
data: The data in bytes.
123+
metadata: The metadata for the conversion.
124+
125+
Returns:
126+
The data converted back to a Numpy array.
127+
"""
128+
array_shape = tuple(metadata["int_list"])
129+
dtype = np.dtype(metadata["dtype"])
130+
flat_array = np.frombuffer(data, dtype=dtype)
131+
return np.reshape(flat_array, newshape=array_shape, order="C")
132+
133+
96134
class TransformationPipeline:
97135
"""Data Transformer Pipeline Class.
98136

openfl/protocols/base.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ message MetadataProto {
2222
map<int32, float> int_to_float = 1;
2323
repeated int32 int_list = 2;
2424
repeated bool bool_list = 3;
25+
repeated string dtype = 4;
2526
}
2627

2728
// handles large size data

0 commit comments

Comments
 (0)