Skip to content

Commit 2d77678

Browse files
committed
chore: address PR comments
1 parent d76b900 commit 2d77678

File tree

2 files changed

+250
-39
lines changed

2 files changed

+250
-39
lines changed

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,61 @@
2828
verify_model_region_and_return_specs,
2929
)
3030
from sagemaker.session import Session
31+
from sagemaker.jumpstart.types import JumpStartModelSpecs
32+
33+
34+
def _retrieve_hosting_prepacked_artifact_key(
35+
model_specs: JumpStartModelSpecs, instance_type: str
36+
) -> str:
37+
"""Returns instance specific hosting prepacked artifact key or default one as fallback."""
38+
instance_specific_prepacked_hosting_artifact_key: Optional[str] = (
39+
model_specs.hosting_instance_type_variants.get_instance_specific_prepacked_artifact_key(
40+
instance_type=instance_type
41+
)
42+
if instance_type
43+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
44+
else None
45+
)
46+
47+
default_prepacked_hosting_artifact_key: Optional[str] = getattr(
48+
model_specs, "hosting_prepacked_artifact_key"
49+
)
50+
51+
return (
52+
instance_specific_prepacked_hosting_artifact_key or default_prepacked_hosting_artifact_key
53+
)
54+
55+
56+
def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
57+
"""Returns instance specific hosting artifact key or default one as fallback."""
58+
instance_specific_hosting_artifact_key: Optional[str] = (
59+
model_specs.hosting_instance_type_variants.get_instance_specific_artifact_key(
60+
instance_type=instance_type
61+
)
62+
if instance_type
63+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
64+
else None
65+
)
66+
67+
default_hosting_artifact_key: str = model_specs.hosting_artifact_key
68+
69+
return instance_specific_hosting_artifact_key or default_hosting_artifact_key
70+
71+
72+
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
73+
"""Returns instance specific training artifact key or default one as fallback."""
74+
instance_specific_training_artifact_key: Optional[str] = (
75+
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
76+
instance_type=instance_type
77+
)
78+
if instance_type
79+
and getattr(model_specs, "training_instance_type_variants", None) is not None
80+
else None
81+
)
82+
83+
default_training_artifact_key: str = model_specs.training_artifact_key
84+
85+
return instance_specific_training_artifact_key or default_training_artifact_key
3186

3287

3388
def _retrieve_model_uri(
@@ -90,52 +145,18 @@ def _retrieve_model_uri(
90145
model_artifact_key: str
91146

92147
if model_scope == JumpStartScriptScope.INFERENCE:
93-
instance_specific_prepacked_hosting_artifact_key: Optional[str] = (
94-
model_specs.hosting_instance_type_variants.get_instance_specific_prepacked_artifact_key(
95-
instance_type=instance_type
96-
)
97-
if instance_type
98-
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
99-
else None
100-
)
101-
102-
instance_specific_hosting_artifact_key: Optional[str] = (
103-
model_specs.hosting_instance_type_variants.get_instance_specific_artifact_key(
104-
instance_type=instance_type
105-
)
106-
if instance_type
107-
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
108-
else None
109-
)
110-
111-
default_prepacked_hosting_artifact_key: Optional[str] = getattr(
112-
model_specs, "hosting_prepacked_artifact_key"
113-
)
114148

115-
default_hosting_artifact_key: str = model_specs.hosting_artifact_key
149+
is_prepacked = not model_specs.use_inference_script_uri()
116150

117151
model_artifact_key = (
118-
instance_specific_prepacked_hosting_artifact_key
119-
or instance_specific_hosting_artifact_key
120-
or default_prepacked_hosting_artifact_key
121-
or default_hosting_artifact_key
152+
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
153+
if is_prepacked
154+
else _retrieve_hosting_artifact_key(model_specs, instance_type)
122155
)
123156

124157
elif model_scope == JumpStartScriptScope.TRAINING:
125-
instance_specific_training_artifact_key: Optional[str] = (
126-
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
127-
instance_type=instance_type
128-
)
129-
if instance_type
130-
and getattr(model_specs, "training_instance_type_variants", None) is not None
131-
else None
132-
)
133158

134-
default_training_artifact_key: str = model_specs.training_artifact_key
135-
136-
model_artifact_key = (
137-
instance_specific_training_artifact_key or default_training_artifact_key
138-
)
159+
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
139160

140161
bucket = os.environ.get(
141162
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

tests/unit/sagemaker/jumpstart/test_artifacts.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,201 @@
1616

1717
from mock.mock import patch
1818

19+
import copy
1920
from sagemaker.jumpstart import artifacts
21+
from sagemaker.jumpstart.artifacts.model_uris import (
22+
_retrieve_hosting_prepacked_artifact_key,
23+
_retrieve_hosting_artifact_key,
24+
_retrieve_training_artifact_key,
25+
)
26+
from sagemaker.jumpstart.types import JumpStartModelSpecs
27+
from tests.unit.sagemaker.jumpstart.constants import (
28+
BASE_SPEC,
29+
)
30+
2031

2132
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
2233

2334

35+
class ModelArtifactVariantsTest(unittest.TestCase):
36+
def test_retrieve_hosting_prepacked_artifact_key(self):
37+
38+
test_spec = copy.deepcopy(BASE_SPEC)
39+
40+
test_spec["hosting_prepacked_artifact_key"] = "some/thing"
41+
42+
test_spec["hosting_instance_type_variants"] = {
43+
"regional_aliases": {
44+
"us-west-2": {
45+
"alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.ama"
46+
"zonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118"
47+
}
48+
},
49+
"variants": {
50+
"c4": {
51+
"regional_properties": {
52+
"image_uri": "$alias_ecr_uri_1",
53+
},
54+
"properties": {
55+
"prepacked_artifact_key": "in/the/way",
56+
},
57+
}
58+
},
59+
}
60+
61+
self.assertEqual(
62+
_retrieve_hosting_prepacked_artifact_key(
63+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
64+
),
65+
"in/the/way",
66+
)
67+
68+
test_spec["hosting_prepacked_artifact_key"] = None
69+
70+
self.assertEqual(
71+
_retrieve_hosting_prepacked_artifact_key(
72+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
73+
),
74+
"in/the/way",
75+
)
76+
77+
test_spec["hosting_instance_type_variants"] = None
78+
79+
self.assertEqual(
80+
_retrieve_hosting_prepacked_artifact_key(
81+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
82+
),
83+
None,
84+
)
85+
86+
test_spec["hosting_prepacked_artifact_key"] = "shemoves"
87+
88+
self.assertEqual(
89+
_retrieve_hosting_prepacked_artifact_key(
90+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
91+
),
92+
"shemoves",
93+
)
94+
95+
def test_retrieve_hosting_artifact_key(self):
96+
97+
test_spec = copy.deepcopy(BASE_SPEC)
98+
99+
test_spec["hosting_artifact_key"] = "some/thing"
100+
101+
test_spec["hosting_instance_type_variants"] = {
102+
"regional_aliases": {
103+
"us-west-2": {
104+
"alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.ama"
105+
"zonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118"
106+
}
107+
},
108+
"variants": {
109+
"c4": {
110+
"regional_properties": {
111+
"image_uri": "$alias_ecr_uri_1",
112+
},
113+
"properties": {
114+
"artifact_key": "in/the/way",
115+
},
116+
}
117+
},
118+
}
119+
120+
self.assertEqual(
121+
_retrieve_hosting_artifact_key(
122+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
123+
),
124+
"in/the/way",
125+
)
126+
127+
test_spec["hosting_artifact_key"] = None
128+
129+
self.assertEqual(
130+
_retrieve_hosting_artifact_key(
131+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
132+
),
133+
"in/the/way",
134+
)
135+
136+
test_spec["hosting_instance_type_variants"] = None
137+
138+
self.assertEqual(
139+
_retrieve_hosting_artifact_key(
140+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
141+
),
142+
None,
143+
)
144+
145+
test_spec["hosting_artifact_key"] = "shemoves"
146+
147+
self.assertEqual(
148+
_retrieve_hosting_artifact_key(
149+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
150+
),
151+
"shemoves",
152+
)
153+
154+
def test_retrieve_training_artifact_key(self):
155+
156+
test_spec = copy.deepcopy(BASE_SPEC)
157+
158+
test_spec["training_artifact_key"] = "some/thing"
159+
160+
test_spec["training_instance_type_variants"] = {
161+
"regional_aliases": {
162+
"us-west-2": {
163+
"alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2."
164+
"amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118"
165+
}
166+
},
167+
"variants": {
168+
"c4": {
169+
"regional_properties": {
170+
"image_uri": "$alias_ecr_uri_1",
171+
},
172+
"properties": {
173+
"artifact_key": "in/the/way",
174+
},
175+
}
176+
},
177+
}
178+
179+
self.assertEqual(
180+
_retrieve_training_artifact_key(
181+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
182+
),
183+
"in/the/way",
184+
)
185+
186+
test_spec["training_artifact_key"] = None
187+
188+
self.assertEqual(
189+
_retrieve_training_artifact_key(
190+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
191+
),
192+
"in/the/way",
193+
)
194+
195+
test_spec["training_instance_type_variants"] = None
196+
197+
self.assertEqual(
198+
_retrieve_training_artifact_key(
199+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
200+
),
201+
None,
202+
)
203+
204+
test_spec["training_artifact_key"] = "shemoves"
205+
206+
self.assertEqual(
207+
_retrieve_training_artifact_key(
208+
JumpStartModelSpecs(test_spec), instance_type="ml.c4.xlarge"
209+
),
210+
"shemoves",
211+
)
212+
213+
24214
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
25215
class RetrieveKwargsTest(unittest.TestCase):
26216

0 commit comments

Comments
 (0)