Skip to content

Commit 743f733

Browse files
authored
Merge pull request #57 from owen-t/master
fix. Add default sagemaker_boto_client to list classmethods.
2 parents 5749437 + ef2af40 commit 743f733

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

src/smexperiments/_base_types.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,23 @@ def _list(
102102
sagemaker_boto_client=None,
103103
**kwargs
104104
):
105+
sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()
105106
next_token = None
106-
while True:
107-
list_request_kwargs = _boto_functions.to_boto(kwargs, cls._custom_boto_names, cls._custom_boto_types)
108-
if next_token:
109-
list_request_kwargs[boto_next_token_name] = next_token
110-
list_method = getattr(sagemaker_boto_client, boto_list_method)
111-
list_method_response = list_method(**list_request_kwargs)
112-
list_items = list_method_response.get(boto_list_items_name, [])
113-
next_token = list_method_response.get(boto_next_token_name)
114-
for item in list_items:
115-
yield list_item_factory(item)
116-
if not next_token:
117-
break
107+
try:
108+
while True:
109+
list_request_kwargs = _boto_functions.to_boto(kwargs, cls._custom_boto_names, cls._custom_boto_types)
110+
if next_token:
111+
list_request_kwargs[boto_next_token_name] = next_token
112+
list_method = getattr(sagemaker_boto_client, boto_list_method)
113+
list_method_response = list_method(**list_request_kwargs)
114+
list_items = list_method_response.get(boto_list_items_name, [])
115+
next_token = list_method_response.get(boto_next_token_name)
116+
for item in list_items:
117+
yield list_item_factory(item)
118+
if not next_token:
119+
break
120+
except StopIteration:
121+
return
118122

119123
@classmethod
120124
def _construct(cls, boto_method_name, sagemaker_boto_client=None, **kwargs):

tests/unit/test_base_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,11 @@ def test_list_with_next_token(sagemaker_boto_client):
150150
"list", DummyRecordSummary.from_boto, "TestRecordSummaries", sagemaker_boto_client=sagemaker_boto_client,
151151
)
152152
)
153+
154+
155+
@unittest.mock.patch("smexperiments._base_types._utils.sagemaker_client")
156+
def test_list_no_client(mocked_utils_sagemaker_client, sagemaker_boto_client):
157+
mocked_utils_sagemaker_client.return_value = sagemaker_boto_client
158+
sagemaker_boto_client.list.side_effect = []
159+
list(DummyRecord._list("list", DummyRecordSummary.from_boto, "TestRecordSummaries"))
160+
assert _base_types._utils.sagemaker_client.called

0 commit comments

Comments
 (0)