Skip to content

Commit 2102bb7

Browse files
kc1998dpsage-maker
andauthored
Fix: hub model reference arn enum bug when constructing from arn (#4972)
* Fix hub model reference arn enum bug * Add unit test for construct hub model reference arn util * fix broken unit test * formatting: add extra newline after unit test * fix broken unit test * fix formatting * add more newlines around test * codestyle: fix line too long * Revert "codestyle: fix line too long" This reverts commit 0b6867a. * fix test * add missing quote --------- Co-authored-by: parknate@ <[email protected]>
1 parent 8be4568 commit 2102bb7

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

src/sagemaker/jumpstart/hub/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def construct_hub_model_reference_arn_from_inputs(
106106
info = get_info_from_hub_resource_arn(hub_arn)
107107
arn = (
108108
f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/"
109-
f"{info.hub_name}/{HubContentType.MODEL_REFERENCE}/{model_name}/{version}"
109+
f"{info.hub_name}/{HubContentType.MODEL_REFERENCE.value}/{model_name}/{version}"
110110
)
111111

112112
return arn

tests/unit/sagemaker/jumpstart/hub/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,23 @@ def test_construct_hub_model_arn_from_inputs():
9696
)
9797

9898

99+
def test_construct_hub_model_reference_arn_from_inputs():
100+
model_name, version = "pytorch-ic-imagenet-v2", "1.0.2"
101+
hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub"
102+
hub_content_arn_prefix = "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub"
103+
104+
assert (
105+
utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version)
106+
== hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/1.0.2"
107+
)
108+
109+
version = "*"
110+
assert (
111+
utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version)
112+
== hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/*"
113+
)
114+
115+
99116
def test_generate_hub_arn_for_init_kwargs():
100117
hub_name = "my-hub-name"
101118
hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"

0 commit comments

Comments
 (0)