Skip to content

Commit 86658a0

Browse files
authored
Merge branch 'master' into hf-pt-inf
2 parents 3dae970 + fda438c commit 86658a0

File tree

8 files changed

+641
-6
lines changed

8 files changed

+641
-6
lines changed

src/sagemaker/local/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
# Environment variables to be set during training
5151
REGION_ENV_NAME = "AWS_REGION"
5252
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
53-
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
53+
S3_ENDPOINT_URL_ENV_NAME = "AWS_ENDPOINT_URL_S3"
5454
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"
5555

5656
# SELinux Enabled

src/sagemaker/modules/train/sm_recipes/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,13 @@ def _get_args_from_nova_recipe(
305305
)
306306
args["hyperparameters"]["kms_key"] = kms_key
307307

308+
# Handle eval custom lambda configuration
309+
if recipe.get("evaluation", {}):
310+
processor = recipe.get("processor", {})
311+
lambda_arn = processor.get("lambda_arn", "")
312+
if lambda_arn:
313+
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
314+
308315
_register_custom_resolvers()
309316

310317
# Resolve Final Recipe

src/sagemaker/pytorch/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,13 @@ def _setup_for_nova_recipe(
12241224
)
12251225
args["hyperparameters"]["kms_key"] = kms_key
12261226

1227+
# Handle eval custom lambda configuration
1228+
if recipe.get("evaluation", {}):
1229+
processor = recipe.get("processor", {})
1230+
lambda_arn = processor.get("lambda_arn", "")
1231+
if lambda_arn:
1232+
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
1233+
12271234
# Resolve and save the final recipe
12281235
self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"])
12291236

src/sagemaker/workflow/emr_step.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
from sagemaker.workflow.properties import (
2222
Properties,
2323
)
24+
from sagemaker.workflow.retry import StepRetryPolicy
2425
from sagemaker.workflow.step_collections import StepCollection
25-
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
26+
from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig
2627

2728

2829
class EMRStepConfig:
@@ -110,8 +111,8 @@ def to_request(self) -> RequestType:
110111
)
111112

112113

113-
class EMRStep(Step):
114-
"""EMR step for workflow."""
114+
class EMRStep(ConfigurableRetryStep):
115+
"""EMR step for workflow with configurable retry policies."""
115116

116117
def _validate_cluster_config(self, cluster_config, step_name):
117118
"""Validates user provided cluster_config.
@@ -164,6 +165,7 @@ def __init__(
164165
cache_config: Optional[CacheConfig] = None,
165166
cluster_config: Optional[Dict[str, Any]] = None,
166167
execution_role_arn: Optional[str] = None,
168+
retry_policies: Optional[List[StepRetryPolicy]] = None,
167169
):
168170
"""Constructs an `EMRStep`.
169171
@@ -200,7 +202,14 @@ def __init__(
200202
called on the cluster specified by ``cluster_id``, so you can only include this
201203
field if ``cluster_id`` is not None.
202204
"""
203-
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
205+
super().__init__(
206+
name=name,
207+
step_type=StepTypeEnum.EMR,
208+
display_name=display_name,
209+
description=description,
210+
depends_on=depends_on,
211+
retry_policies=retry_policies,
212+
)
204213

205214
emr_step_args = {"StepConfig": step_config.to_request()}
206215
root_property = Properties(step_name=name, step=self, shape_name="Step", service_name="emr")
@@ -248,7 +257,7 @@ def properties(self) -> RequestType:
248257
return self._properties
249258

250259
def to_request(self) -> RequestType:
251-
"""Updates the dictionary with cache configuration."""
260+
"""Updates the dictionary with cache configuration and retry policies"""
252261
request_dict = super().to_request()
253262
if self.cache_config:
254263
request_dict.update(self.cache_config.config)

tests/integ/sagemaker/workflow/test_emr_steps.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
2121
from sagemaker.workflow.parameters import ParameterInteger
2222
from sagemaker.workflow.pipeline import Pipeline
23+
from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum
2324

2425

2526
@pytest.fixture
@@ -134,3 +135,215 @@ def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_
134135
pipeline.delete()
135136
except Exception:
136137
pass
138+
139+
140+
def test_emr_with_retry_policies(sagemaker_session, role, pipeline_name, region_name):
141+
"""Test EMR steps with retry policies in both cluster_id and cluster_config scenarios."""
142+
emr_step_config = EMRStepConfig(
143+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
144+
args=["dummy_emr_script_path"],
145+
)
146+
147+
retry_policies = [
148+
StepRetryPolicy(
149+
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
150+
interval_seconds=1,
151+
max_attempts=3,
152+
backoff_rate=2.0,
153+
)
154+
]
155+
156+
# Step with existing cluster and retry policies
157+
step_emr_1 = EMRStep(
158+
name="emr-step-1",
159+
cluster_id="j-1YONHTCP3YZKC",
160+
display_name="emr_step_1",
161+
description="EMR Step with retry policies",
162+
step_config=emr_step_config,
163+
retry_policies=retry_policies,
164+
)
165+
166+
# Step with cluster config and retry policies
167+
cluster_config = {
168+
"Instances": {
169+
"InstanceGroups": [
170+
{
171+
"Name": "Master Instance Group",
172+
"InstanceRole": "MASTER",
173+
"InstanceCount": 1,
174+
"InstanceType": "m1.small",
175+
"Market": "ON_DEMAND",
176+
}
177+
],
178+
"InstanceCount": 1,
179+
"HadoopVersion": "MyHadoopVersion",
180+
},
181+
"AmiVersion": "3.8.0",
182+
"AdditionalInfo": "MyAdditionalInfo",
183+
}
184+
185+
step_emr_2 = EMRStep(
186+
name="emr-step-2",
187+
display_name="emr_step_2",
188+
description="EMR Step with cluster config and retry policies",
189+
cluster_id=None,
190+
step_config=emr_step_config,
191+
cluster_config=cluster_config,
192+
retry_policies=retry_policies,
193+
)
194+
195+
pipeline = Pipeline(
196+
name=pipeline_name,
197+
steps=[step_emr_1, step_emr_2],
198+
sagemaker_session=sagemaker_session,
199+
)
200+
201+
try:
202+
response = pipeline.create(role)
203+
create_arn = response["PipelineArn"]
204+
assert re.match(
205+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
206+
create_arn,
207+
)
208+
finally:
209+
try:
210+
pipeline.delete()
211+
except Exception:
212+
pass
213+
214+
215+
def test_emr_with_expire_after_retry_policy(sagemaker_session, role, pipeline_name, region_name):
216+
"""Test EMR step with retry policy using expire_after_mins."""
217+
emr_step_config = EMRStepConfig(
218+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
219+
args=["dummy_emr_script_path"],
220+
)
221+
222+
retry_policies = [
223+
StepRetryPolicy(
224+
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
225+
interval_seconds=1,
226+
expire_after_mins=30,
227+
backoff_rate=2.0,
228+
)
229+
]
230+
231+
step_emr = EMRStep(
232+
name="emr-step-expire",
233+
cluster_id="j-1YONHTCP3YZKC",
234+
display_name="emr_step_expire",
235+
description="EMR Step with expire after retry policy",
236+
step_config=emr_step_config,
237+
retry_policies=retry_policies,
238+
)
239+
240+
pipeline = Pipeline(
241+
name=pipeline_name,
242+
steps=[step_emr],
243+
sagemaker_session=sagemaker_session,
244+
)
245+
246+
try:
247+
response = pipeline.create(role)
248+
create_arn = response["PipelineArn"]
249+
assert re.match(
250+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
251+
create_arn,
252+
)
253+
finally:
254+
try:
255+
pipeline.delete()
256+
except Exception:
257+
pass
258+
259+
260+
def test_emr_with_multiple_exception_types(sagemaker_session, role, pipeline_name, region_name):
261+
"""Test EMR step with multiple exception types in retry policy."""
262+
retry_policies = [
263+
StepRetryPolicy(
264+
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING],
265+
interval_seconds=1,
266+
max_attempts=3,
267+
backoff_rate=2.0,
268+
)
269+
]
270+
271+
step_emr = EMRStep(
272+
name="emr-step-multi-except",
273+
cluster_id="j-1YONHTCP3YZKC",
274+
display_name="emr_step_multi_except",
275+
description="EMR Step with multiple exception types",
276+
step_config=EMRStepConfig(
277+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
278+
args=["dummy_emr_script_path"],
279+
),
280+
retry_policies=retry_policies,
281+
)
282+
283+
pipeline = Pipeline(
284+
name=pipeline_name,
285+
steps=[step_emr],
286+
sagemaker_session=sagemaker_session,
287+
)
288+
289+
try:
290+
response = pipeline.create(role)
291+
create_arn = response["PipelineArn"]
292+
assert re.match(
293+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
294+
create_arn,
295+
)
296+
finally:
297+
try:
298+
pipeline.delete()
299+
except Exception:
300+
pass
301+
302+
303+
def test_emr_with_multiple_retry_policies(sagemaker_session, role, pipeline_name, region_name):
304+
"""Test EMR step with multiple retry policies."""
305+
retry_policies = [
306+
StepRetryPolicy(
307+
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
308+
interval_seconds=1,
309+
max_attempts=3,
310+
backoff_rate=2.0,
311+
),
312+
StepRetryPolicy(
313+
exception_types=[StepExceptionTypeEnum.THROTTLING],
314+
interval_seconds=5,
315+
expire_after_mins=60,
316+
backoff_rate=1.5,
317+
),
318+
]
319+
320+
step_emr = EMRStep(
321+
name="emr-step-multi-policy",
322+
cluster_id="j-1YONHTCP3YZKC",
323+
display_name="emr_step_multi_policy",
324+
description="EMR Step with multiple retry policies",
325+
step_config=EMRStepConfig(
326+
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
327+
args=["dummy_emr_script_path"],
328+
),
329+
retry_policies=retry_policies,
330+
)
331+
332+
pipeline = Pipeline(
333+
name=pipeline_name,
334+
steps=[step_emr],
335+
sagemaker_session=sagemaker_session,
336+
)
337+
338+
try:
339+
response = pipeline.create(role)
340+
create_arn = response["PipelineArn"]
341+
assert re.match(
342+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
343+
create_arn,
344+
)
345+
finally:
346+
try:
347+
pipeline.delete()
348+
except Exception:
349+
pass

tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,35 @@ def test_get_args_from_nova_recipe_with_distillation_errors(test_case):
446446
_get_args_from_nova_recipe(
447447
recipe=recipe, compute=test_case["compute"], role=test_case.get("role")
448448
)
449+
450+
451+
@pytest.mark.parametrize(
452+
"test_case",
453+
[
454+
{
455+
"recipe": {
456+
"evaluation": {"task:": "gen_qa", "strategy": "gen_qa", "metric": "all"},
457+
"processor": {
458+
"lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction"
459+
},
460+
},
461+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
462+
"role": "arn:aws:iam::123456789012:role/SageMakerRole",
463+
"expected_args": {
464+
"compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
465+
"hyperparameters": {
466+
"eval_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyLambdaFunction",
467+
},
468+
"training_image": None,
469+
"source_code": None,
470+
"distributed": None,
471+
},
472+
},
473+
],
474+
)
475+
def test_get_args_from_nova_recipe_with_evaluation(test_case):
476+
recipe = OmegaConf.create(test_case["recipe"])
477+
args, _ = _get_args_from_nova_recipe(
478+
recipe=recipe, compute=test_case["compute"], role=test_case["role"]
479+
)
480+
assert args == test_case["expected_args"]

0 commit comments

Comments
 (0)