Skip to content

Commit 40155cd

Browse files
committed
Make LocalRuntime work with support for the exclusion of prohibited data types.
Signed-off-by: yuliasherman <[email protected]>
1 parent a208651 commit 40155cd

File tree

6 files changed

+77
-17
lines changed

6 files changed

+77
-17
lines changed

openfl/experimental/workflow/interface/fl_spec.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
filter_attributes,
2323
generate_artifacts,
2424
should_transfer,
25+
validate_data_types,
2526
)
2627

2728

@@ -343,8 +344,14 @@ def next(self, f, **kwargs) -> None:
343344
if aggregator_to_collaborator(f, parent_func):
344345
agg_to_collab_ss = self._capture_instance_snapshot(kwargs=kwargs)
345346

346-
# Remove included / excluded attributes from next task
347-
filter_attributes(self, f, **kwargs)
347+
# Remove prohibited attributes from the next task
348+
if kwargs:
349+
filter_attributes(self, f, **kwargs)
350+
if self._runtime._prohibited_data_types:
351+
# try:
352+
validate_data_types(self._runtime._prohibited_data_types, **kwargs)
353+
# except Exception as exception:
354+
# print(exception)
348355

349356
if str(self._runtime) == "FederatedRuntime":
350357
if f.collaborator_step and not f.aggregator_step:

openfl/experimental/workflow/runtime/federated_runtime.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
collaborators: Optional[List[str]] = None,
4343
director: Optional[Dict[str, Any]] = None,
4444
notebook_path: Optional[str] = None,
45+
prohibited_data_types: Optional[List[str]] = None,
4546
tls: bool = False,
4647
) -> None:
4748
"""Initializes the FederatedRuntime object.
@@ -51,9 +52,11 @@ def __init__(
5152
Defaults to None.
5253
director (Optional[Dict[str, Any]]): Director information. Defaults to None
5354
notebook_path (Optional[str]): Jupyter notebook path
55+
prohibited_data_types (List[str]): A list of data types that are not allowed to be sent
56+
through the network.
5457
tls (bool): Whether to use TLS for the connection.
5558
"""
56-
super().__init__()
59+
super().__init__(prohibited_data_types=prohibited_data_types)
5760
self.__collaborators = collaborators
5861

5962
self.tls = tls

openfl/experimental/workflow/runtime/local_runtime.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def __init__(
302302
aggregator: Dict = None,
303303
collaborators: Dict = None,
304304
backend: str = "single_process",
305+
prohibited_data_types: Optional[List[str]] = None,
305306
**kwargs,
306307
) -> None:
307308
"""Initializes the LocalRuntime object to run the flow on a single
@@ -314,6 +315,8 @@ def __init__(
314315
collaborators (List[Type[Collaborator]], optional): A list of
315316
collaborators; each with their own private attributes.
316317
backend (str, optional): The backend that will execute the tasks.
318+
prohibited_data_types (List[str]): A list of data types that are not allowed to be sent
319+
through the network.
317320
Defaults to "single_process".
318321
Available options are:
319322
- 'single_process': (default) Executes every task within the
@@ -324,7 +327,7 @@ def __init__(
324327
The RayGroups run concurrently while participants in the
325328
group run serially.
326329
The default is 1 RayGroup and can be changed by using the
327-
num_actors=1 kwarg. By using more RayGroups more concurency
330+
num_actors=1 kwarg. By using more RayGroups more concurrency
328331
is allowed with the trade off being that each RayGroup has
329332
extra memory overhead in the form of extra CUDA CONTEXTS.
330333
@@ -346,7 +349,7 @@ def some_collaborator_task(self):
346349
access. If the system has one GPU, collaborator tasks will run
347350
sequentially.
348351
"""
349-
super().__init__()
352+
super().__init__(prohibited_data_types=prohibited_data_types)
350353
if backend not in ["ray", "single_process"]:
351354
raise ValueError(
352355
f"Invalid 'backend' value '{backend}', accepted values are "
@@ -737,7 +740,9 @@ def execute_collab_task(self, flspec_obj, f, parent_func, instance_snapshot, **k
737740
# Set new LocalRuntime for clone as it is required
738741
# new runtime object will not contain private attributes of
739742
# aggregator or other collaborators
740-
clone.runtime = LocalRuntime(backend="single_process")
743+
clone.runtime = LocalRuntime(
744+
backend="single_process", prohibited_data_types=super().prohibited_data_types
745+
)
741746

742747
# write the clone to the object store
743748
# ensure clone is getting latest _metaflow_interface

openfl/experimental/workflow/runtime/runtime.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,37 @@
44

55
"""openfl.experimental.workflow.runtime module Runtime class."""
66

7-
from typing import Callable, List
7+
from typing import Callable, List, Optional
88

99
from openfl.experimental.workflow.interface.fl_spec import FLSpec
1010
from openfl.experimental.workflow.interface.participants import Aggregator, Collaborator
1111

1212

1313
class Runtime:
14-
def __init__(self):
14+
def __init__(self, prohibited_data_types: Optional[List[str]] = None):
1515
"""Initializes the Runtime object.
1616
1717
This serves as a base interface for runtimes that can run FLSpec flows.
18+
19+
Args:
20+
prohibited_data_types (Optional[List[str]]): A list of data types that are not allowed to be sent
21+
through the network. Defaults to an empty list if not provided.
1822
"""
19-
pass
23+
self.prohibited_data_types = prohibited_data_types or []
24+
25+
@property
26+
def prohibited_data_types(self) -> List[str]:
27+
"""Return the prohibited data types for the runtime."""
28+
return self._prohibited_data_types
29+
30+
@prohibited_data_types.setter
31+
def prohibited_data_types(self, value: List[str]):
32+
"""Set the prohibited data types for the runtime."""
33+
if not isinstance(value, list):
34+
raise TypeError("prohibited_data_types must be a list of strings.")
35+
if not all(isinstance(item, str) for item in value):
36+
raise ValueError("All items in prohibited_data_types must be strings.")
37+
self._prohibited_data_types = value
2038

2139
@property
2240
def aggregator(self):

openfl/experimental/workflow/utilities/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
filter_attributes,
1818
generate_artifacts,
1919
parse_attrs,
20+
validate_data_types,
2021
)
2122
from openfl.experimental.workflow.utilities.stream_redirect import (
2223
RedirectStdStream,

openfl/experimental/workflow/utilities/runtime_utils.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import inspect
88
import itertools
99
from types import MethodType
10+
from typing import List
1011

1112
import numpy as np
1213

@@ -96,6 +97,31 @@ def filter_attributes(ctx, f, **kwargs):
9697
_process_exclusion(ctx, cls_attrs, kwargs["exclude"], f)
9798

9899

100+
def validate_data_types(
101+
prohibited_data_types: List[str], reserved_words=["collaborators"], **kwargs
102+
):
103+
"""Validates that the types of attributes in kwargs are not among the prohibited data types.
104+
Raises a TypeError if any prohibited data type is found.
105+
106+
Args:
107+
prohibited_data_types (List[str]): A list of prohibited data type names
108+
(e.g., ['int', 'float']).
109+
kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
110+
111+
Raises:
112+
TypeError: If any prohibited data types are found in kwargs.
113+
ValueError: If prohibited_data_types is empty.
114+
"""
115+
if not prohibited_data_types:
116+
raise ValueError("prohibited_data_types must not be empty.")
117+
for attr_name, attr_value in kwargs.items():
118+
if type(attr_value).__name__ in prohibited_data_types and attr_value not in reserved_words:
119+
raise TypeError(
120+
f"The attribute '{attr_name}' = '{attr_value}' has a prohibited value type: "
121+
f"{type(attr_value).__name__}"
122+
)
123+
124+
99125
def _validate_include_exclude(kwargs, cls_attrs):
100126
"""Validates that 'include' and 'exclude' are not both present, and that
101127
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +178,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152178
delattr(ctx, attr)
153179

154180

155-
def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
181+
def checkpoint(ctx, parent_func, checkpoint_reserved_words=["next", "runtime"]):
156182
"""Optionally saves the current state for the task just executed.
157183
158184
Args:
159185
ctx (any): The context to checkpoint.
160186
parent_func (function): The function that was just executed.
161-
chkpnt_reserved_words (list, optional): A list of reserved words to
187+
checkpoint_reserved_words (list, optional): A list of reserved words to
162188
exclude from checkpointing. Defaults to ["next", "runtime"].
163189
164190
Returns:
@@ -173,7 +199,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173199
if ctx._checkpoint:
174200
# all objects will be serialized using Metaflow interface
175201
print(f"Saving data artifacts for {parent_func.__name__}")
176-
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=chkpnt_reserved_words)
202+
artifacts_iter, _ = generate_artifacts(ctx=ctx, reserved_words=checkpoint_reserved_words)
177203
task_id = ctx._metaflow_interface.create_task(parent_func.__name__)
178204
ctx._metaflow_interface.save_artifacts(
179205
artifacts_iter(),
@@ -188,15 +214,15 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
188214

189215
def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
190216
remaining_gpu_memory = {}
191-
# TODO for each GPU the funtion tries see if all participant usages fit
217+
# TODO for each GPU the function tries see if all participant usages fit
192218
# into a GPU, it it doesn't it removes that participant from the
193219
# participant list, and adds it to the remaining_gpu_memory dict. So any
194220
# sum of GPU requirements above 1 triggers this.
195-
# But at this point the funtion will raise an error because
221+
# But at this point the function will raise an error because
196222
# remaining_gpu_memory is never cleared.
197223
# The participant list should remove the participant if it fits in the gpu
198-
# and save the partipant if it doesn't and continue to the next GPU to see
199-
# if it fits in that one, only if we run out of GPUs should this funtion
224+
# and save the participant if it doesn't and continue to the next GPU to see
225+
# if it fits in that one, only if we run out of GPUs should this function
200226
# raise an error.
201227
for gpu in np.ones(num_gpus, dtype=int):
202228
for i, (participant_name, participant_gpu_usage) in enumerate(
@@ -230,7 +256,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230256
if gpu == 0:
231257
break
232258
if gpu < participant_gpu_usage:
233-
# participant doesn't fitm break to next GPU
259+
# participant doesn't fit; break to next GPU
234260
break
235261
else:
236262
# if participant fits remove from need_assigned

0 commit comments

Comments
 (0)