Skip to content

Commit 85375a9

Browse files
authored
Add search for training job and fix slow tests (#71)
* Add search for training job and fix slow tests * fix black format * Improve tests * fix black format
1 parent a121057 commit 85375a9

File tree

7 files changed

+234
-25
lines changed

7 files changed

+234
-25
lines changed

src/smexperiments/api_types.py

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,85 @@ def __init__(self, code=None, message=None, metric_index=None, **kwargs):
272272
super(BatchPutMetricsError, self).__init__(code=code, message=message, metric_index=metric_index, **kwargs)
273273

274274

275+
class TrainingJobSearchResult(_base_types.ApiObject):
276+
"""Summary model of an Training Job search result.
277+
278+
Attributes:
279+
training_job_name (str): The name of the training job.
280+
training_job_arn (str): The Amazon Resource Name (ARN) of the training job.
281+
tuning_job_arn (str): The Amazon Resource Name (ARN) of the associated
282+
hyperparameter tuning job if the training job was launched by a hyperparameter tuning job.
283+
labeling_job_arn (str): The Amazon Resource Name (ARN) of the labeling job.
284+
autoML_job_arn (str): The Amazon Resource Name (ARN) of the AutoML job.
285+
model_artifacts (dict): Information about the Amazon S3 location that is configured for storing model artifacts.
286+
training_job_status (str): The status of the training job.
287+
hyper_parameters (dict): Algorithm-specific parameters.
288+
algorithm_specification (dict): Information about the algorithm used for training, and algorithm metadata.
289+
input_data_config (dict): An array of Channel objects that describes each data input channel.
290+
output_data_config (dict): The S3 path where model artifacts that you configured when creating the job are
291+
stored. Amazon SageMaker creates subfolders for model artifacts.
292+
resource_config (dict): Resources, including ML compute instances and ML storage volumes, that are configured
293+
for model training.
294+
debug_hook_config (dict): Configuration information for the debug hook parameters, collection configuration,
295+
and storage paths.
296+
debug_rule_config (dict): Information about the debug rule configuration.
297+
"""
298+
299+
training_job_name = None
300+
training_job_arn = None
301+
tuning_job_arn = None
302+
labeling_job_arn = None
303+
autoML_job_arn = None
304+
model_artifacts = None
305+
training_job_status = None
306+
hyper_parameters = None
307+
algorithm_specification = None
308+
input_data_config = None
309+
output_data_config = None
310+
resource_config = None
311+
debug_hook_config = None
312+
experiment_config = None
313+
debug_rule_config = None
314+
315+
def __init__(
316+
self,
317+
training_job_arn=None,
318+
training_job_name=None,
319+
tuning_job_arn=None,
320+
labeling_job_arn=None,
321+
autoML_job_arn=None,
322+
model_artifacts=None,
323+
training_job_status=None,
324+
hyper_parameters=None,
325+
algorithm_specification=None,
326+
input_data_config=None,
327+
output_data_config=None,
328+
resource_config=None,
329+
debug_hook_config=None,
330+
experiment_config=None,
331+
debug_rule_config=None,
332+
**kwargs
333+
):
334+
super(TrainingJobSearchResult, self).__init__(
335+
training_job_arn=training_job_arn,
336+
training_job_name=training_job_name,
337+
tuning_job_arn=tuning_job_arn,
338+
labeling_job_arn=labeling_job_arn,
339+
autoML_job_arn=autoML_job_arn,
340+
model_artifacts=model_artifacts,
341+
training_job_status=training_job_status,
342+
hyper_parameters=hyper_parameters,
343+
algorithm_specification=algorithm_specification,
344+
input_data_config=input_data_config,
345+
output_data_config=output_data_config,
346+
resource_config=resource_config,
347+
debug_hook_config=debug_hook_config,
348+
experiment_config=experiment_config,
349+
debug_rule_config=debug_rule_config,
350+
**kwargs
351+
)
352+
353+
275354
class ExperimentSearchResult(_base_types.ApiObject):
276355
"""Summary model of an Experiment search result.
277356
@@ -369,12 +448,6 @@ class TrialComponentSearchResult(_base_types.ApiObject):
369448
display_name = None
370449
source = None
371450
status = None
372-
start_time = None
373-
end_time = None
374-
creation_time = None
375-
created_by = None
376-
last_modified_time = None
377-
last_modified_by = None
378451
parameters = None
379452
input_artifacts = None
380453
output_artifacts = None
@@ -392,35 +465,27 @@ def __init__(
392465
display_name=None,
393466
source=None,
394467
status=None,
395-
creation_time=None,
396-
created_by=None,
397-
last_modified_time=None,
398-
last_modified_by=None,
399468
parameters=None,
400469
input_artifacts=None,
401470
output_artifacts=None,
402471
metrics=None,
403472
source_detail=None,
404473
tags=None,
405474
parents=None,
475+
**kwargs
406476
):
407477
super(TrialComponentSearchResult, self).__init__(
408478
trial_component_arn=trial_component_arn,
409479
trial_component_name=trial_component_name,
410480
display_name=display_name,
411481
source=source,
412482
status=status,
413-
start_time=start_time,
414-
end_time=end_time,
415-
creation_time=creation_time,
416-
created_by=created_by,
417-
last_modified_by=last_modified_by,
418-
last_modified_time=last_modified_time,
419483
parameters=parameters,
420484
input_artifacts=input_artifacts,
421485
output_artifacts=output_artifacts,
422486
metrics=metrics,
423487
source_detail=source_detail,
424488
tags=tags,
425489
parents=parents,
490+
**kwargs
426491
)

src/smexperiments/training_job.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""Contains the SageMaker Training Job class."""
15+
from smexperiments import _base_types, api_types
16+
17+
18+
class TrainingJob(_base_types.Record):
19+
@classmethod
20+
def search(
21+
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
22+
):
23+
"""
24+
Search training jobs. Returns the SearchResults in the account that match the search criteria.
25+
26+
Args:
27+
search_expression (SearchExpression, optional): A Boolean conditional statement. Resource objects
28+
must satisfy this condition to be included in search results. You must provide at
29+
least one subexpression, filter, or nested filter.
30+
sort_by (str, optional): The name of the resource property used to sort the SearchResults.
31+
The default is LastModifiedTime.
32+
sort_order (str, optional): How SearchResults are ordered. Valid values are Ascending or
33+
Descending. The default is Descending.
34+
max_results (int, optional): The maximum number of results to return in a SearchResponse.
35+
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not
36+
supplied, a default boto3 client will be used.
37+
38+
Returns:
39+
collections.Iterator[TrainingJobSearchResult]: An iterator over search results matching the search criteria.
40+
"""
41+
return super(TrainingJob, cls)._search(
42+
search_resource="TrainingJob",
43+
search_item_factory=api_types.TrainingJobSearchResult.from_boto,
44+
search_expression=None if search_expression is None else search_expression.to_boto(),
45+
sort_by=sort_by,
46+
sort_order=sort_order,
47+
max_results=max_results,
48+
sagemaker_boto_client=sagemaker_boto_client,
49+
)

tests/conftest.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,7 @@ def training_job_name(sagemaker_boto_client, training_role_arn, docker_image, tr
264264
"DataSource": {"S3DataSource": {"S3Uri": training_s3_uri, "S3DataType": "S3Prefix"}},
265265
}
266266
],
267-
AlgorithmSpecification={
268-
"TrainingImage": docker_image,
269-
"TrainingInputMode": "File",
270-
"EnableSageMakerMetricsTimeSeries": True,
271-
},
267+
AlgorithmSpecification={"TrainingImage": docker_image, "TrainingInputMode": "File",},
272268
RoleArn=training_role_arn,
273269
ResourceConfig={"InstanceType": "ml.m5.large", "InstanceCount": 1, "VolumeSizeInGB": 10},
274270
StoppingCondition={"MaxRuntimeInSeconds": 900},

tests/integ/test_track_from_processing_job.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020

2121
@pytest.mark.slow
2222
def test_track_from_processing_job(sagemaker_boto_client, processing_job_name):
23-
2423
get_job = lambda: sagemaker_boto_client.describe_processing_job(ProcessingJobName=processing_job_name)
2524
processing_job = get_job()
2625

2726
source_arn = processing_job["ProcessingJobArn"]
28-
# wait_for_job(processing_job_name, get_job, "ProcessingJobStatus")
27+
wait_for_job(processing_job_name, get_job, "ProcessingJobStatus")
2928

3029
print(processing_job)
3130
if "ProcessingStartTime" in processing_job:

tests/integ/test_track_from_training_job.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -21,7 +21,6 @@
2121

2222
@pytest.mark.slow
2323
def test_track_from_training_job(sagemaker_boto_client, training_job_name):
24-
training_job_name = "smexperiments-integ-eca5c064-3a64-433e-a30a-2963338d71d8"
2524
get_job = lambda: sagemaker_boto_client.describe_training_job(TrainingJobName=training_job_name)
2625
tj = get_job()
2726
source_arn = tj["TrainingJobArn"]

tests/integ/test_training_job.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import pytest
14+
15+
from smexperiments.training_job import TrainingJob
16+
from smexperiments.search_expression import SearchExpression, Filter, Operator
17+
from tests.helpers import retry
18+
19+
20+
@pytest.mark.slow
21+
def test_search(sagemaker_boto_client, training_job_name, docker_image):
22+
def validate():
23+
training_job_searched = []
24+
search_filter = Filter(name="TrainingJobName", operator=Operator.EQUALS, value=training_job_name)
25+
search_expression = SearchExpression(filters=[search_filter])
26+
for s in TrainingJob.search(
27+
search_expression=search_expression, max_results=10, sagemaker_boto_client=sagemaker_boto_client,
28+
):
29+
training_job_searched.append(s)
30+
31+
assert len(training_job_searched) == 1
32+
assert training_job_searched[0].training_job_name == training_job_name
33+
assert training_job_searched[0].input_data_config[0]["ChannelName"] == "train"
34+
assert training_job_searched[0].algorithm_specification == {
35+
"TrainingImage": docker_image,
36+
"TrainingInputMode": "File",
37+
}
38+
assert training_job_searched[0].resource_config == {
39+
"InstanceType": "ml.m5.large",
40+
"InstanceCount": 1,
41+
"VolumeSizeInGB": 10,
42+
}
43+
assert training_job_searched[0].stopping_condition == {"MaxRuntimeInSeconds": 900}
44+
assert training_job_searched # sanity test
45+
46+
retry(validate)

tests/unit/test_training_job.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import pytest
14+
import unittest.mock
15+
16+
from smexperiments import training_job, api_types
17+
18+
19+
@pytest.fixture
20+
def sagemaker_boto_client():
21+
return unittest.mock.Mock()
22+
23+
24+
def test_search(sagemaker_boto_client):
25+
sagemaker_boto_client.search.return_value = {
26+
"Results": [
27+
{
28+
"TrainingJob": {
29+
"TrainingJobName": "training-1",
30+
"TrainingJobArn": "arn::training-1",
31+
"HyperParameters": {"learning_rate": "0.1"},
32+
}
33+
},
34+
{
35+
"TrainingJob": {
36+
"TrainingJobName": "training-2",
37+
"TrainingJobArn": "arn::training-2",
38+
"HyperParameters": {"learning_rate": "0.2"},
39+
}
40+
},
41+
]
42+
}
43+
expected = [
44+
api_types.TrainingJobSearchResult(
45+
training_job_name="training-1",
46+
training_job_arn="arn::training-1",
47+
hyper_parameters={"learning_rate": "0.1"},
48+
),
49+
api_types.TrainingJobSearchResult(
50+
training_job_name="training-2",
51+
training_job_arn="arn::training-2",
52+
hyper_parameters={"learning_rate": "0.2"},
53+
),
54+
]
55+
assert expected == list(training_job.TrainingJob.search(sagemaker_boto_client=sagemaker_boto_client))

0 commit comments

Comments
 (0)