Skip to content

Commit b6a0953

Browse files
authored
Support ListTrials by trial component name (#67)
* Add SageMaker analytics example * Support ListTrials by trial component name * fix flake8 format * fix black format * Update min boto version required * Update boto3 version in tox.ini
1 parent a438e68 commit b6a0953

File tree

5 files changed

+41
-3
lines changed

5 files changed

+41
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def read(fname):
2424

2525

2626
# Declare minimal set for installation
27-
required_packages = ["boto3>=1.10.32"]
27+
required_packages = ["boto3>=1.12.8"]
2828

2929
# Open readme with original (i.e. LF) newlines
3030
# to prevent the all too common "`long_description_content_type` missing"

src/smexperiments/trial.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
125125
def list(
126126
cls,
127127
experiment_name=None,
128+
trial_component_name=None,
128129
created_before=None,
129130
created_after=None,
130131
sort_by=None,
@@ -136,6 +137,8 @@ def list(
136137
Args:
137138
experiment_name (str, optional): Name of the experiment. If specified, only trials in
138139
the experiment will be returned.
140+
trial_component_name (str, optional): Name of the trial component. If specified, only
141+
trials with this trial component name will be returned.
139142
created_before (datetime.datetime, optional): Return trials created before this instant.
140143
created_after (datetime.datetime, optional): Return trials created after this instant.
141144
sort_by (str, optional): Which property to sort results by. One of 'Name',
@@ -153,6 +156,7 @@ def list(
153156
api_types.TrialSummary.from_boto,
154157
"TrialSummaries",
155158
experiment_name=experiment_name,
159+
trial_component_name=trial_component_name,
156160
created_before=created_before,
157161
created_after=created_after,
158162
sort_by=sort_by,

tests/integ/test_trial.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@ def test_list(trials, sagemaker_boto_client):
3535
assert trial_names_listed # sanity test
3636

3737

38+
def test_list_with_trial_component(trials, trial_component_obj, sagemaker_boto_client):
39+
trial_with_component = trials[0]
40+
trial_with_component.add_trial_component(trial_component_obj)
41+
42+
trial_listed = [
43+
s.trial_name
44+
for s in trial.Trial.list(
45+
trial_component_name=trial_component_obj.trial_component_name, sagemaker_boto_client=sagemaker_boto_client
46+
)
47+
]
48+
assert len(trial_listed) == 1
49+
assert trial_with_component.trial_name == trial_listed[0]
50+
# clean up
51+
trial_with_component.remove_trial_component(trial_component_obj)
52+
assert trial_listed
53+
54+
3855
def test_list_sort(trials, sagemaker_boto_client):
3956
slack = datetime.timedelta(minutes=1)
4057
now = datetime.datetime.now(datetime.timezone.utc)

tests/unit/test_trial.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ def test_list_trials_with_experiment_name(sagemaker_boto_client, datetime_obj):
123123
sagemaker_boto_client.list_trials.assert_called_with(ExperimentName="foo")
124124

125125

126+
def test_list_trials_with_trial_component_name(sagemaker_boto_client, datetime_obj):
127+
sagemaker_boto_client.list_trials.return_value = {
128+
"TrialSummaries": [
129+
{"TrialName": "trial-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,},
130+
{"TrialName": "trial-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,},
131+
]
132+
}
133+
expected = [
134+
api_types.TrialSummary(trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj),
135+
api_types.TrialSummary(trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj),
136+
]
137+
assert expected == list(
138+
trial.Trial.list(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client)
139+
)
140+
sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="tc-foo")
141+
142+
126143
def test_delete(sagemaker_boto_client):
127144
obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
128145
sagemaker_boto_client.delete_trial.return_value = {}

tox.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ commands =
5353
{env:IGNORE_COVERAGE:} coverage report --fail-under=95
5454
extras = test
5555
deps =
56-
boto3 >= 1.10.32
56+
boto3 >= 1.12.8
5757
python-dateutil
5858
pytest
5959
pytest-cov
@@ -98,7 +98,7 @@ commands =
9898
pytest {posargs} --verbose --runslow --capture=no
9999
extras = test
100100
deps =
101-
boto3 >= 1.10.32
101+
boto3 >= 1.12.8
102102
pytest
103103
docker
104104

0 commit comments

Comments
 (0)