Skip to content

Commit 2f39d8e

Browse files
committed
fix: create trial with trial components
1 parent 85375a9 commit 2f39d8e

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/smexperiments/trial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
118118
)
119119
if trial_components:
120120
for tc in trial_components:
121-
trial.add_trial_components(*trial_components)
121+
trial.add_trial_component(tc)
122122
return trial
123123

124124
@classmethod

tests/unit/test_trial.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,28 @@ def test_create_no_name(sagemaker_boto_client):
5757
assert kwargs["TrialName"] # confirm that a TrialName was passed
5858

5959

60+
def test_create_with_trial_components(sagemaker_boto_client):
61+
sagemaker_boto_client.create_trial.return_value = {
62+
"Arn": "arn:aws:1234",
63+
"TrialName": "name-value",
64+
}
65+
tc = trial_component.TrialComponent(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client)
66+
67+
trial_obj = trial.Trial.create(
68+
trial_name="name-value",
69+
experiment_name="experiment-name-value",
70+
trial_components=[tc],
71+
sagemaker_boto_client=sagemaker_boto_client,
72+
)
73+
assert trial_obj.trial_name == "name-value"
74+
sagemaker_boto_client.create_trial.assert_called_with(
75+
TrialName="name-value", ExperimentName="experiment-name-value"
76+
)
77+
sagemaker_boto_client.associate_trial_component.assert_called_with(
78+
TrialName="name-value", TrialComponentName=tc.trial_component_name
79+
)
80+
81+
6082
def test_add_trial_component(sagemaker_boto_client):
6183
t = trial.Trial(sagemaker_boto_client)
6284
t.trial_name = "bar"

0 commit comments

Comments
 (0)