Skip to content

Commit 0d78708

Browse files
authored
Add search classmethod for Record (#68)
* Add search classmethod for Record * limit search result to prevent throttling
1 parent b6a0953 commit 0d78708

File tree

11 files changed

+371
-0
lines changed

11 files changed

+371
-0
lines changed

src/smexperiments/_base_types.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,35 @@ def _list(
131131
except StopIteration:
132132
return
133133

134+
@classmethod
135+
def _search(
136+
cls,
137+
search_resource,
138+
search_item_factory,
139+
boto_next_token_name="NextToken",
140+
sagemaker_boto_client=None,
141+
**kwargs
142+
):
143+
sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()
144+
next_token = None
145+
try:
146+
while True:
147+
search_request_kwargs = _boto_functions.to_boto(kwargs, cls._custom_boto_names, cls._custom_boto_types)
148+
search_request_kwargs["Resource"] = search_resource
149+
if next_token:
150+
search_request_kwargs[boto_next_token_name] = next_token
151+
search_method = getattr(sagemaker_boto_client, "search")
152+
search_method_response = search_method(**search_request_kwargs)
153+
search_items = search_method_response.get("Results", [])
154+
next_token = search_method_response.get(boto_next_token_name)
155+
for item in search_items:
156+
if cls.__name__ in item:
157+
yield search_item_factory(item[cls.__name__])
158+
if not next_token:
159+
break
160+
except StopIteration:
161+
return
162+
134163
@classmethod
135164
def _construct(cls, boto_method_name, sagemaker_boto_client=None, **kwargs):
136165
sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()

src/smexperiments/api_types.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,157 @@ class BatchPutMetricsError(_base_types.ApiObject):
270270

271271
def __init__(self, code=None, message=None, metric_index=None, **kwargs):
272272
super(BatchPutMetricsError, self).__init__(code=code, message=message, metric_index=metric_index, **kwargs)
273+
274+
275+
class ExperimentSearchResult(_base_types.ApiObject):
276+
"""Summary model of an Experiment search result.
277+
278+
Attributes:
279+
experiment_arn (str): Arn of the experiment.
280+
experiment_name (str): Name of the experiment.
281+
display_name (str): Display name of the experiment.
282+
source (dict): The source of the experiment
283+
tags (list): The list of tags that are associated with the experiment.
284+
"""
285+
286+
experiment_arn = None
287+
experiment_name = None
288+
display_name = None
289+
source = None
290+
tags = None
291+
292+
def __init__(self, experiment_arn=None, experiment_name=None, display_name=None, source=None, tags=None, **kwargs):
293+
super(ExperimentSearchResult, self).__init__(
294+
experiment_arn=experiment_arn,
295+
experiment_name=experiment_name,
296+
display_name=display_name,
297+
source=source,
298+
tags=tags,
299+
**kwargs
300+
)
301+
302+
303+
class TrialSearchResult(_base_types.ApiObject):
304+
"""Summary model of an Trial search result.
305+
306+
Attributes:
307+
trial_arn (str): Arn of the trial.
308+
trial_name (str): Name of the trial.
309+
display_name (str): Display name of the trial.
310+
source (dict): The source of the trial.
311+
tags (list): The list of tags that are associated with the trial.
312+
trial_component_summaries (dict):
313+
"""
314+
315+
trial_arn = None
316+
trial_name = None
317+
display_name = None
318+
source = None
319+
tags = None
320+
trial_component_summaries = None
321+
322+
def __init__(
323+
self,
324+
trial_arn=None,
325+
trial_name=None,
326+
display_name=None,
327+
source=None,
328+
tags=None,
329+
trial_component_summaries=None,
330+
**kwargs
331+
):
332+
super(TrialSearchResult, self).__init__(
333+
trial_arn=trial_arn,
334+
trial_name=trial_name,
335+
display_name=display_name,
336+
source=source,
337+
tags=tags,
338+
trial_component_summaries=trial_component_summaries,
339+
**kwargs
340+
)
341+
342+
343+
class TrialComponentSearchResult(_base_types.ApiObject):
344+
"""Summary model of an Trial Component search result.
345+
346+
Attributes:
347+
trial_component_arn (str): Arn of the trial component.
348+
trial_component_name (str): Name of the trial component.
349+
display_name (str): Display name of the trial component.
350+
source (dict): The source of the trial component.
351+
status (dict): The status of the trial component.
352+
start_time (datetime): Start time.
353+
end_time (datetime): End time.
354+
creation_time (datetime): Creation time.
355+
created_by (str): Created by.
356+
last_modified_time (datetime): Date last modified.
357+
last_modified_by (datetime): User last modified.
358+
parameters (dict): The hyperparameters of the component.
359+
input_artifacts (dict): The input artifacts of the component.
360+
output_artifacts (dict): The output artifacts of the component.
361+
metrics (list): The metrics for the component.
362+
source_detail (dict): The source of the trial component.
363+
tags (list): The list of tags that are associated with the trial component.
364+
parents (dict): The parent of trial component
365+
"""
366+
367+
trial_component_arn = None
368+
trial_component_name = None
369+
display_name = None
370+
source = None
371+
status = None
372+
start_time = None
373+
end_time = None
374+
creation_time = None
375+
created_by = None
376+
last_modified_time = None
377+
last_modified_by = None
378+
parameters = None
379+
input_artifacts = None
380+
output_artifacts = None
381+
metrics = None
382+
source_detail = None
383+
tags = None
384+
parents = None
385+
386+
def __init__(
387+
self,
388+
trial_component_arn=None,
389+
trial_component_name=None,
390+
start_time=None,
391+
end_time=None,
392+
display_name=None,
393+
source=None,
394+
status=None,
395+
creation_time=None,
396+
created_by=None,
397+
last_modified_time=None,
398+
last_modified_by=None,
399+
parameters=None,
400+
input_artifacts=None,
401+
output_artifacts=None,
402+
metrics=None,
403+
source_detail=None,
404+
tags=None,
405+
parents=None,
406+
):
407+
super(TrialComponentSearchResult, self).__init__(
408+
trial_component_arn=trial_component_arn,
409+
trial_component_name=trial_component_name,
410+
display_name=display_name,
411+
source=source,
412+
status=status,
413+
start_time=start_time,
414+
end_time=end_time,
415+
creation_time=creation_time,
416+
created_by=created_by,
417+
last_modified_by=last_modified_by,
418+
last_modified_time=last_modified_time,
419+
parameters=parameters,
420+
input_artifacts=input_artifacts,
421+
output_artifacts=output_artifacts,
422+
metrics=metrics,
423+
source_detail=source_detail,
424+
tags=tags,
425+
parents=parents,
426+
)

src/smexperiments/experiment.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,38 @@ def list(
145145
sagemaker_boto_client=sagemaker_boto_client,
146146
)
147147

148+
@classmethod
149+
def search(
150+
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
151+
):
152+
"""
153+
Search experiments. Returns SearchResults in the account matching the search criteria.
154+
155+
Args:
156+
search_expression: (dict, optional): A Boolean conditional statement. Resource objects
157+
must satisfy this condition to be included in search results. You must provide at
158+
least one subexpression, filter, or nested filter.
159+
sort_by (str, optional): The name of the resource property used to sort the SearchResults.
160+
The default is LastModifiedTime
161+
sort_order (str, optional): How SearchResults are ordered. Valid values are Ascending or
162+
Descending . The default is Descending .
163+
max_results (int, optional): The maximum number of results to return in a SearchResponse.
164+
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not
165+
supplied, a default boto3 client will be used.
166+
167+
Returns:
168+
collections.Iterator[SearchResult] : An iterator over search results matching the search criteria.
169+
"""
170+
return super(Experiment, cls)._search(
171+
search_resource="Experiment",
172+
search_item_factory=api_types.ExperimentSearchResult.from_boto,
173+
search_expression=search_expression,
174+
sort_by=sort_by,
175+
sort_order=sort_order,
176+
max_results=max_results,
177+
sagemaker_boto_client=sagemaker_boto_client,
178+
)
179+
148180
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):
149181
"""List trials in this experiment matching the specified criteria.
150182

src/smexperiments/trial.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,38 @@ def list(
164164
sagemaker_boto_client=sagemaker_boto_client,
165165
)
166166

167+
@classmethod
168+
def search(
169+
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
170+
):
171+
"""
172+
Search experiments. Returns SearchResults in the account matching the search criteria.
173+
174+
Args:
175+
search_expression: (dict, optional): A Boolean conditional statement. Resource objects
176+
must satisfy this condition to be included in search results. You must provide at
177+
least one subexpression, filter, or nested filter.
178+
sort_by (str, optional): The name of the resource property used to sort the SearchResults.
179+
The default is LastModifiedTime
180+
sort_order (str, optional): How SearchResults are ordered. Valid values are Ascending or
181+
Descending . The default is Descending .
182+
max_results (int, optional): The maximum number of results to return in a SearchResponse.
183+
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not
184+
supplied, a default boto3 client will be used.
185+
186+
Returns:
187+
collections.Iterator[SearchResult] : An iterator over search results matching the search criteria.
188+
"""
189+
return super(Trial, cls)._search(
190+
search_resource="ExperimentTrial",
191+
search_item_factory=api_types.TrialSearchResult.from_boto,
192+
search_expression=search_expression,
193+
sort_by=sort_by,
194+
sort_order=sort_order,
195+
max_results=max_results,
196+
sagemaker_boto_client=sagemaker_boto_client,
197+
)
198+
167199
def add_trial_component(self, tc):
168200
"""Add the specified trial component to this ``Trial``.
169201

src/smexperiments/trial_component.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,35 @@ def list(
179179
max_results=max_results,
180180
next_token=next_token,
181181
)
182+
183+
@classmethod
184+
def search(
185+
cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_boto_client=None,
186+
):
187+
"""
188+
Search experiments. Returns SearchResults in the account matching the search criteria.
189+
190+
Args:
191+
search_expression: (dict, optional): A Boolean conditional statement. Resource objects
192+
must satisfy this condition to be included in search results. You must provide at
193+
least one subexpression, filter, or nested filter.
194+
sort_by (str, optional): The name of the resource property used to sort the SearchResults.
195+
The default is LastModifiedTime
196+
sort_order (str, optional): How SearchResults are ordered. Valid values are Ascending or
197+
Descending . The default is Descending .
198+
max_results (int, optional): The maximum number of results to return in a SearchResponse.
199+
sagemaker_boto_client (SageMaker.Client, optional): Boto3 client for SageMaker. If not
200+
supplied, a default boto3 client will be used.
201+
202+
Returns:
203+
collections.Iterator[SearchResult] : An iterator over search results matching the search criteria.
204+
"""
205+
return super(TrialComponent, cls)._search(
206+
search_resource="ExperimentTrialComponent",
207+
search_item_factory=api_types.TrialComponentSearchResult.from_boto,
208+
search_expression=search_expression,
209+
sort_by=sort_by,
210+
sort_order=sort_order,
211+
max_results=max_results,
212+
sagemaker_boto_client=sagemaker_boto_client,
213+
)

tests/integ/test_experiment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ def test_list_sort(sagemaker_boto_client, experiments):
8585
assert experiment_names # sanity test
8686

8787

88+
def test_search(sagemaker_boto_client):
89+
experiment_names_searched = []
90+
for s in experiment.Experiment.search(max_results=10, sagemaker_boto_client=sagemaker_boto_client):
91+
if "smexperiments-integ-" in s.experiment_name:
92+
experiment_names_searched.append(s.experiment_name)
93+
94+
assert len(experiment_names_searched) > 0
95+
assert experiment_names_searched # sanity test
96+
97+
8898
def test_create_trial(experiment_obj, sagemaker_boto_client):
8999
trial_obj = experiment_obj.create_trial()
90100
try:

tests/integ/test_trial.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ def test_list_sort(trials, sagemaker_boto_client):
7676
assert trial_names_listed # sanity test
7777

7878

79+
def test_search(sagemaker_boto_client):
80+
trial_names_searched = []
81+
for s in trial.Trial.search(max_results=10, sagemaker_boto_client=sagemaker_boto_client):
82+
if "smexperiments-integ-" in s.trial_name:
83+
trial_names_searched.append(s.trial_name)
84+
85+
assert len(trial_names_searched) > 0
86+
assert trial_names_searched # sanity test
87+
88+
7989
def test_add_remove_trial_component(trial_obj, trial_component_obj):
8090
trial_obj.add_trial_component(trial_component_obj)
8191
trial_components = list(trial_obj.list_trial_components())

tests/integ/test_trial_component.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,13 @@ def test_list_trial_components_by_experiment(experiment_obj, trial_component_obj
9999
)
100100
assert 0 == len(trial_components)
101101
trial_obj.delete()
102+
103+
104+
def test_search(sagemaker_boto_client):
105+
trial_component_names_searched = []
106+
for s in trial_component.TrialComponent.search(max_results=10, sagemaker_boto_client=sagemaker_boto_client):
107+
if "smexperiments-integ-" in s.trial_component_name:
108+
trial_component_names_searched.append(s.trial_component_name)
109+
110+
assert len(trial_component_names_searched) > 0
111+
assert trial_component_names_searched # sanity test

tests/unit/test_experiment.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,36 @@ def test_list_trials_call_args(sagemaker_boto_client):
129129
sagemaker_boto_client.list_trials.assert_called_with(CreatedBefore=created_before, CreatedAfter=created_after)
130130

131131

132+
def test_search(sagemaker_boto_client):
133+
sagemaker_boto_client.search.return_value = {
134+
"Results": [
135+
{
136+
"Experiment": {
137+
"ExperimentName": "experiment-1",
138+
"ExperimentArn": "arn::experiment-1",
139+
"DisplayName": "Experiment1",
140+
}
141+
},
142+
{
143+
"Experiment": {
144+
"ExperimentName": "experiment-2",
145+
"ExperimentArn": "arn::experiment-2",
146+
"DisplayName": "Experiment2",
147+
}
148+
},
149+
]
150+
}
151+
expected = [
152+
api_types.ExperimentSearchResult(
153+
experiment_name="experiment-1", experiment_arn="arn::experiment-1", display_name="Experiment1"
154+
),
155+
api_types.ExperimentSearchResult(
156+
experiment_name="experiment-2", experiment_arn="arn::experiment-2", display_name="Experiment2",
157+
),
158+
]
159+
assert expected == list(experiment.Experiment.search(sagemaker_boto_client=sagemaker_boto_client))
160+
161+
132162
def test_experiment_create_trial_with_name(sagemaker_boto_client):
133163
experiment_obj = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
134164
experiment_obj.experiment_name = "someExperimentName"

0 commit comments

Comments
 (0)