Skip to content

Commit e0cc4e1

Browse files
author
Chuyang Deng
committed
fix dict KeyError and and unit tests
1 parent 1e7799f commit e0cc4e1

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

src/sagemaker/cli/framework_upgrade.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
IMAGE_URI_CONFIG_DIR = os.path.join("..", "image_uri_config")
2323

2424

25-
def _get_latest_values(existing_content, scope=None):
25+
def get_latest_values(existing_content, scope=None):
2626
"""Get the latest "registries", "py_versions" and "repository" values
2727
2828
Args:
@@ -40,10 +40,10 @@ def _get_latest_values(existing_content, scope=None):
4040
)
4141
)
4242

43-
latest_version = list(existing_content[scope]["versions"].keys())[-1]
44-
registries = existing_content[scope]["versions"][latest_version]["registries"]
45-
py_versions = existing_content[scope]["versions"][latest_version]["py_versions"]
46-
repository = existing_content[scope]["versions"][latest_version]["repository"]
43+
latest_version = list(existing_content["versions"].keys())[-1]
44+
registries = existing_content["versions"][latest_version]["registries"]
45+
py_versions = existing_content["versions"][latest_version]["py_versions"]
46+
repository = existing_content["versions"][latest_version]["repository"]
4747

4848
return registries, py_versions, repository
4949

@@ -173,7 +173,7 @@ def add_version(
173173
"""
174174
py_versions = py_versions.split(",")
175175
processors = processors.split(",")
176-
latest_registries, latest_py_versions, latest_repository = _get_latest_values(
176+
latest_registries, latest_py_versions, latest_repository = get_latest_values(
177177
existing_content, scope
178178
)
179179
if not py_versions:

tests/unit/sagemaker/cli/test_framework_version_upgrade_tool.py renamed to tests/unit/sagemaker/cli/test_framework_upgrade_tool.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,22 @@ def test_algo_add_region(algo_region_content, algo_expected_region_content):
218218
account = "987654321098"
219219
framework_upgrade.add_region(algo_region_content, region, account)
220220
assert algo_region_content == algo_expected_region_content
221+
222+
223+
def test_dlc_get_latest_content(dlc_content):
224+
latest_version = "1.0.0"
225+
scope = "eia"
226+
registries, py_versions, repository = framework_upgrade.get_latest_values(
227+
dlc_content, scope=scope
228+
)
229+
assert registries == dlc_content[scope]["versions"][latest_version]["registries"]
230+
assert py_versions == dlc_content[scope]["versions"][latest_version]["py_versions"]
231+
assert repository == dlc_content[scope]["versions"][latest_version]["repository"]
232+
233+
234+
def test_algo_get_latest_content(algo_content):
235+
latest_version = "0.10"
236+
registries, py_versions, repository = framework_upgrade.get_latest_values(algo_content)
237+
assert registries == algo_content["versions"][latest_version]["registries"]
238+
assert py_versions == algo_content["versions"][latest_version]["py_versions"]
239+
assert repository == algo_content["versions"][latest_version]["repository"]

0 commit comments

Comments
 (0)