Skip to content

Commit a60b7b5

Browse files
committed
feature - create metrics writer in Tracker.create path
1 parent 000d787 commit a60b7b5

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/smexperiments/tracker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def load(
115115
else:
116116
raise ValueError('Could not load TrialComponent. Specify a trial_component_name or invoke "create"')
117117

118+
# if running in a SageMaker context write metrics to file
118119
if not trial_component_name and tce.environment_type == _environment.EnvironmentType.SageMakerTrainingJob:
119120
metrics_writer = metrics.SageMakerFileMetricsWriter()
120121
else:
@@ -160,8 +161,13 @@ def create(
160161
display_name=display_name,
161162
sagemaker_boto_client=sagemaker_boto_client,
162163
)
164+
165+
metrics_writer = metrics.SageMakerFileMetricsWriter()
166+
163167
return cls(
164-
tc, None, _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session)
168+
tc,
169+
metrics_writer,
170+
_ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session),
165171
)
166172

167173
def log_parameter(self, name, value):

tests/unit/test_tracker.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,23 @@ def test_load(boto3_session, sagemaker_boto_client):
117117
)
118118

119119

120+
def test_create(boto3_session, sagemaker_boto_client):
121+
trial_component_name = "foo-trial-component"
122+
trial_component_display_name = "foo-trial-component-display-name"
123+
sagemaker_boto_client.create_trial_component.return_value = {"TrialComponentName": trial_component_name}
124+
tracker_created = tracker.Tracker.create(
125+
display_name=trial_component_display_name, sagemaker_boto_client=sagemaker_boto_client
126+
)
127+
assert trial_component_name == tracker_created.trial_component.trial_component_name
128+
129+
assert tracker_created._metrics_writer is not None
130+
131+
tracker_created._metrics_writer = unittest.mock.Mock()
132+
now = datetime.datetime.now()
133+
tracker_created.log_metric("foo", 1.0, 1, now)
134+
tracker_created._metrics_writer.log_metric.assert_called_with("foo", 1.0, 1, now)
135+
136+
120137
@pytest.fixture
121138
def trial_component_obj(sagemaker_boto_client):
122139
return trial_component.TrialComponent(sagemaker_boto_client)

0 commit comments

Comments
 (0)