7
7
import inspect
8
8
import itertools
9
9
from types import MethodType
10
+ from typing import List
10
11
11
12
import numpy as np
12
13
@@ -95,6 +96,32 @@ def filter_attributes(ctx, f, **kwargs):
95
96
elif "exclude" in kwargs :
96
97
_process_exclusion (ctx , cls_attrs , kwargs ["exclude" ], f )
97
98
99
+ # This function validates the data types of the attributes in kwargs. 'reserved_words' parameter
100
+ # can be passed as a list of strings that should be accepted as 'str' attributes.
101
+ def validate_data_types (
102
+ prohibited_data_types : List [str ], reserved_words = ["collaborators" ], ** kwargs
103
+ ):
104
+ """Validates that the types of attributes in kwargs are not among the prohibited data types.
105
+ Raises a TypeError if any prohibited data type is found.
106
+
107
+ Args:
108
+ prohibited_data_types (List[str]): A list of prohibited data type names
109
+ (e.g., ['int', 'float']).
110
+ kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
111
+
112
+ Raises:
113
+ TypeError: If any prohibited data types are found in kwargs.
114
+ ValueError: If prohibited_data_types is empty.
115
+ """
116
+ if not prohibited_data_types :
117
+ raise ValueError ("prohibited_data_types must not be empty." )
118
+ for attr_name , attr_value in kwargs .items ():
119
+ if type (attr_value ).__name__ in prohibited_data_types and attr_value not in reserved_words :
120
+ raise TypeError (
121
+ f"The attribute '{ attr_name } ' = '{ attr_value } ' has a prohibited value type: "
122
+ f"{ type (attr_value ).__name__ } "
123
+ )
124
+
98
125
99
126
def _validate_include_exclude (kwargs , cls_attrs ):
100
127
"""Validates that 'include' and 'exclude' are not both present, and that
@@ -152,13 +179,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152
179
delattr (ctx , attr )
153
180
154
181
155
- def checkpoint (ctx , parent_func , chkpnt_reserved_words = ["next" , "runtime" ]):
182
+ def checkpoint (ctx , parent_func , checkpoint_reserved_words = ["next" , "runtime" ]):
156
183
"""Optionally saves the current state for the task just executed.
157
184
158
185
Args:
159
186
ctx (any): The context to checkpoint.
160
187
parent_func (function): The function that was just executed.
161
- chkpnt_reserved_words (list, optional): A list of reserved words to
188
+ checkpoint_reserved_words (list, optional): A list of reserved words to
162
189
exclude from checkpointing. Defaults to ["next", "runtime"].
163
190
164
191
Returns:
@@ -173,7 +200,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173
200
if ctx ._checkpoint :
174
201
# all objects will be serialized using Metaflow interface
175
202
print (f"Saving data artifacts for { parent_func .__name__ } " )
176
- artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = chkpnt_reserved_words )
203
+ artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = checkpoint_reserved_words )
177
204
task_id = ctx ._metaflow_interface .create_task (parent_func .__name__ )
178
205
ctx ._metaflow_interface .save_artifacts (
179
206
artifacts_iter (),
@@ -188,15 +215,15 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
188
215
189
216
def old_check_resource_allocation (num_gpus , each_participant_gpu_usage ):
190
217
remaining_gpu_memory = {}
191
- # TODO for each GPU the funtion tries see if all participant usages fit
218
+ # TODO for each GPU the function tries see if all participant usages fit
192
219
# into a GPU, it it doesn't it removes that participant from the
193
220
# participant list, and adds it to the remaining_gpu_memory dict. So any
194
221
# sum of GPU requirements above 1 triggers this.
195
- # But at this point the funtion will raise an error because
222
+ # But at this point the function will raise an error because
196
223
# remaining_gpu_memory is never cleared.
197
224
# The participant list should remove the participant if it fits in the gpu
198
- # and save the partipant if it doesn't and continue to the next GPU to see
199
- # if it fits in that one, only if we run out of GPUs should this funtion
225
+ # and save the participant if it doesn't and continue to the next GPU to see
226
+ # if it fits in that one, only if we run out of GPUs should this function
200
227
# raise an error.
201
228
for gpu in np .ones (num_gpus , dtype = int ):
202
229
for i , (participant_name , participant_gpu_usage ) in enumerate (
@@ -230,7 +257,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230
257
if gpu == 0 :
231
258
break
232
259
if gpu < participant_gpu_usage :
233
- # participant doesn't fitm break to next GPU
260
+ # participant doesn't fit, break to next GPU
234
261
break
235
262
else :
236
263
# if participant fits remove from need_assigned
0 commit comments