Skip to content

Commit 3f42c15

Browse files
committed
fix: instance type found but image uri for family
1 parent 2f22669 commit 3f42c15

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,12 +356,10 @@ def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
356356
if None in [self.regional_aliases, self.variants]:
357357
return None
358358

359-
image_uri_alias: Optional[str] = None
360-
if instance_type in self.variants:
361-
image_uri_alias = (
362-
self.variants[instance_type].get("regional_properties", {}).get("image_uri")
363-
)
364-
else:
359+
image_uri_alias: Optional[str] = (
360+
self.variants.get(instance_type, {}).get("regional_properties", {}).get("image_uri")
361+
)
362+
if image_uri_alias is None:
365363
instance_type_family = get_instance_type_family(instance_type)
366364

367365
if instance_type_family in {"", None}:

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def test_jumpstart_instance_variants():
175175
"ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}},
176176
"p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
177177
"g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
178+
"g9": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
178179
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
179180
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
180181
"local": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
@@ -184,6 +185,9 @@ def test_jumpstart_instance_variants():
184185
"ml.g5.12xlarge": {
185186
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
186187
},
188+
"ml.g9.12xlarge": {
189+
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
190+
},
187191
},
188192
}
189193
)
@@ -192,6 +196,12 @@ def test_jumpstart_instance_variants():
192196
== "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu"
193197
)
194198

199+
assert (
200+
variant.get_image_uri(instance_type="ml.g9.12xlarge", region="us-west-2")
201+
== "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:"
202+
"1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04"
203+
)
204+
195205
assert (
196206
variant.get_image_uri(instance_type="ml.p3.2xlarge", region="us-west-2")
197207
== "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:"

0 commit comments

Comments
 (0)