Skip to content

Commit 6341eb7

Browse files
committed
change: don't require instance_type for image_uris.retrieve() if only one option
1 parent abd873e commit 6341eb7

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

src/sagemaker/image_uris.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ def _processor(instance_type, available_processors):
173173
logger.info("Ignoring unnecessary instance type: %s.", instance_type)
174174
return None
175175

176+
if len(available_processors) == 1 and not instance_type:
177+
logger.info("Defaulting to only supported image scope: %s.", available_processors[0])
178+
return available_processors[0]
179+
176180
if not instance_type:
177181
raise ValueError(
178182
"Empty SageMaker instance type. For options, see: "

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,22 @@ def test_retrieve_processor_type_from_version_specific_processor_config(config_f
522522
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.1.0-py3" == uri
523523

524524

525+
@patch("sagemaker.image_uris.config_for_framework")
526+
def test_retrieve_default_processor_type_if_possible(config_for_framework):
527+
config = copy.deepcopy(BASE_CONFIG)
528+
config["processors"] = ["cpu"]
529+
config_for_framework.return_value = config
530+
531+
uri = image_uris.retrieve(
532+
framework="useless-string",
533+
version="1.0.0",
534+
py_version="py3",
535+
region="us-west-2",
536+
image_scope="training",
537+
)
538+
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri
539+
540+
525541
@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
526542
def test_retrieve_unsupported_processor_type(config_for_framework):
527543
with pytest.raises(ValueError) as e:

tests/unit/sagemaker/image_uris/test_xgboost.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@
7272
def test_xgboost_framework(xgboost_framework_version):
7373
for region in regions.regions():
7474
uri = image_uris.retrieve(
75-
framework="xgboost",
76-
region=region,
77-
version=xgboost_framework_version,
78-
py_version="py3",
79-
instance_type="ml.c4.xlarge",
75+
framework="xgboost", region=region, version=xgboost_framework_version, py_version="py3",
8076
)
8177

8278
expected = expected_uris.framework_uri(

0 commit comments

Comments
 (0)