Skip to content

Commit e1fe7f4

Browse files
author
Rui Wang Napieralski
committed
fix unit tests
1 parent 64c768d commit e1fe7f4

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/sagemaker/image_uri_config/neo-pytorch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"us-west-1": "710691900526",
3737
"us-west-2": "301217895009"
3838
},
39-
"repository": "sagemaker-neo-pytorch"
39+
"repository": "sagemaker-inference-pytorch"
4040
}
4141
}
4242
}

tests/unit/sagemaker/image_uris/test_neo.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,26 @@ def test_algo_uris(algo):
6262

6363

6464
def _test_neo_framework_uris(framework, version):
65-
framework = "neo-{}".format(framework)
65+
framework_in_config = f"neo-{framework}"
66+
framework_in_uri = f"neo-{framework}" if framework == "tensorflow" else f"inference-{framework}"
6667

6768
for region in regions.regions():
6869
if region in ACCOUNTS:
69-
uri = image_uris.retrieve(framework, region, instance_type="ml_c5", version=version)
70-
assert _expected_framework_uri(framework, version, region=region) == uri
70+
uri = image_uris.retrieve(
71+
framework_in_config, region, instance_type="ml_c5", version=version
72+
)
73+
assert _expected_framework_uri(framework_in_uri, version, region=region) == uri
7174
else:
7275
with pytest.raises(ValueError) as e:
73-
image_uris.retrieve(framework, region, instance_type="ml_c5", version=version)
76+
image_uris.retrieve(
77+
framework_in_config, region, instance_type="ml_c5", version=version
78+
)
7479
assert "Unsupported region: {}.".format(region) in str(e.value)
7580

76-
uri = image_uris.retrieve(framework, "us-west-2", instance_type="ml_p2", version=version)
77-
assert _expected_framework_uri(framework, version, processor="gpu") == uri
81+
uri = image_uris.retrieve(
82+
framework_in_config, "us-west-2", instance_type="ml_p2", version=version
83+
)
84+
assert _expected_framework_uri(framework_in_uri, version, processor="gpu") == uri
7885

7986

8087
def test_neo_mxnet(neo_mxnet_version):

0 commit comments

Comments
 (0)