Skip to content

Commit 2efa1e8

Browse files
committed
feat: jumpstart instance type variants
1 parent 017160f commit 2efa1e8

File tree

7 files changed

+488
-21
lines changed

7 files changed

+488
-21
lines changed

src/sagemaker/image_uris.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -262,20 +262,6 @@ def retrieve(
262262
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo)
263263

264264

265-
def _get_instance_type_family(instance_type):
266-
"""Return the family of the instance type.
267-
268-
Regex matches either "ml.<family>.<size>" or "ml_<family>". If input is None
269-
or there is no match, return an empty string.
270-
"""
271-
instance_type_family = ""
272-
if isinstance(instance_type, str):
273-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
274-
if match is not None:
275-
instance_type_family = match[1]
276-
return instance_type_family
277-
278-
279265
def _get_image_tag(
280266
container_version,
281267
distribution,
@@ -289,7 +275,7 @@ def _get_image_tag(
289275
version,
290276
):
291277
"""Return image tag based on framework, container, and compute configuration(s)."""
292-
instance_type_family = _get_instance_type_family(instance_type)
278+
instance_type_family = utils.get_instance_type_family(instance_type)
293279
if framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
294280
if instance_type_family and final_image_scope == INFERENCE_GRAVITON:
295281
_validate_arg(
@@ -377,7 +363,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
377363

378364
def _validate_instance_deprecation(framework, instance_type, version):
379365
"""Check if instance type is deprecated for a certain framework with a certain version"""
380-
if _get_instance_type_family(instance_type) == "p2":
366+
if utils.get_instance_type_family(instance_type) == "p2":
381367
if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
382368
framework == "tensorflow" and Version(version) >= Version("2.12")
383369
):
@@ -401,7 +387,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
401387
# Validate for Graviton allowed frameworks
402388
if (
403389
instance_type is not None
404-
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
390+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
405391
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
406392
):
407393
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
@@ -418,7 +404,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
418404
"""Return final image scope based on provided framework and instance type."""
419405
if (
420406
framework in GRAVITON_ALLOWED_FRAMEWORKS
421-
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
407+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
422408
):
423409
return INFERENCE_GRAVITON
424410
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
@@ -433,7 +419,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
433419
def _get_inference_tool(inference_tool, instance_type):
434420
"""Extract the inference tool name from instance type."""
435421
if not inference_tool:
436-
instance_type_family = _get_instance_type_family(instance_type)
422+
instance_type_family = utils.get_instance_type_family(instance_type)
437423
if instance_type_family.startswith("inf") or instance_type_family.startswith("trn"):
438424
return "neuron"
439425
return inference_tool
@@ -517,7 +503,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
517503
processor = "neuron"
518504
else:
519505
# looks for either "ml.<family>.<size>" or "ml_<family>"
520-
family = _get_instance_type_family(instance_type)
506+
family = utils.get_instance_type_family(instance_type)
521507
if family:
522508
# For some frameworks, we have optimized images for specific families, e.g c5 or p3.
523509
# In those cases, we use the family name in the image tag. In other cases, we use
@@ -547,7 +533,7 @@ def _should_auto_select_container_version(instance_type, distribution):
547533
p4d = False
548534
if instance_type:
549535
# looks for either "ml.<family>.<size>" or "ml_<family>"
550-
family = _get_instance_type_family(instance_type)
536+
family = utils.get_instance_type_family(instance_type)
551537
if family:
552538
p4d = family == "p4d"
553539

src/sagemaker/jumpstart/artifacts/image_uris.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,30 @@ def _retrieve_image_uri(
111111
)
112112

113113
if image_scope == JumpStartScriptScope.INFERENCE:
114+
hosting_instance_type_variants = model_specs.hosting_instance_type_variants
115+
if hosting_instance_type_variants:
116+
image_uri = hosting_instance_type_variants.get_image_uri(
117+
instance_type=instance_type, region=region
118+
)
119+
if image_uri is None:
120+
raise ValueError(
121+
f"Inference image uri is unavailable for model id '{model_id}' "
122+
f"with '{instance_type}' instance type in '{region}' region."
123+
)
124+
return image_uri
114125
ecr_specs = model_specs.hosting_ecr_specs
115126
elif image_scope == JumpStartScriptScope.TRAINING:
127+
training_instance_type_variants = model_specs.training_instance_type_variants
128+
if training_instance_type_variants:
129+
image_uri = training_instance_type_variants.get_image_uri(
130+
instance_type=instance_type, region=region
131+
)
132+
if image_uri is None:
133+
raise ValueError(
134+
f"Training image uri is unavailable for model id '{model_id}' "
135+
f"with '{instance_type}' instance type in '{region}' region."
136+
)
137+
return image_uri
116138
ecr_specs = model_specs.training_ecr_specs
117139

118140
if framework is not None and framework != ecr_specs.framework:

src/sagemaker/jumpstart/types.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from copy import deepcopy
1616
from enum import Enum
1717
from typing import Any, Dict, List, Optional, Set, Union
18+
from sagemaker.utils import get_instance_type_family
1819

1920

2021
class JumpStartDataHolderType:
@@ -309,6 +310,69 @@ def to_json(self) -> Dict[str, Any]:
309310
return json_obj
310311

311312

313+
class JumpStartInstanceTypeVariants(JumpStartDataHolderType):
    """Data class for JumpStart instance type variants."""

    __slots__ = [
        "aliases",
        "variants",
    ]

    def __init__(self, spec: Optional[Dict[str, Any]]):
        """Initializes a JumpStartInstanceTypeVariants object from its json representation.

        Args:
            spec (Optional[Dict[str, Any]]): Dictionary representation of instance type
                variants. If None, no fields are set on the object.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
        """Sets fields in object based on json.

        Args:
            json_obj (Optional[Dict[str, Any]]): Dictionary representation of instance
                type variants. If None, the object is left without any field set.

        Raises:
            KeyError: If ``json_obj`` lacks the "aliases" or the "variants" key.
        """

        if json_obj is None:
            return

        self.aliases: dict = json_obj["aliases"]
        # The attribute name must match the ``variants`` entry in ``__slots__``;
        # ``get_image_uri`` reads it under that name.
        self.variants: dict = json_obj["variants"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartInstanceTypeVariants object."""
        # Only attributes that were actually set are serialized.
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj

    def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
        """Returns image uri from instance type and region.

        The exact instance type is looked up first; if it has no entry, the instance
        type family is used instead.

        Args:
            instance_type (str): Instance type to look up.
            region (str): Region whose aliases are used to resolve the image uri.

        Returns:
            Optional[str]: The resolved image uri, or None if no variant matches the
            instance type, the matching variant has no "image_uri" property, or the
            region has no aliases.

        Raises:
            TypeError: If the variant's image uri does not start with '$'.
            KeyError: If a matching variant has no "properties" entry, or the alias
                is not defined for the region.
        """

        # An object built from a ``None`` spec has no ``variants`` attribute; treat
        # that the same as having no variants instead of raising AttributeError.
        variants: dict = getattr(self, "variants", None) or {}

        image_uri_alias: Optional[str] = None
        if instance_type in variants:
            image_uri_alias = variants[instance_type]["properties"].get("image_uri")
        else:
            instance_type_family = get_instance_type_family(instance_type)
            if instance_type_family in variants:
                image_uri_alias = variants[instance_type_family]["properties"].get("image_uri")

        if image_uri_alias is None:
            return None

        if not image_uri_alias.startswith("$"):
            raise TypeError("All image uris should map to an alias an start with '$'.")

        if region not in self.aliases:
            return None

        # Drop the leading '$' to obtain the key into the region's alias mapping.
        alias_value = self.aliases[region][image_uri_alias[1:]]
        return alias_value
374+
375+
312376
class JumpStartModelSpecs(JumpStartDataHolderType):
313377
"""Data class JumpStart model specs."""
314378

@@ -357,6 +421,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
357421
"hosting_model_package_arns",
358422
"training_model_package_artifact_uris",
359423
"hosting_use_script_uri",
424+
"hosting_instance_type_variants",
425+
"training_instance_type_variants",
360426
]
361427

362428
def __init__(self, spec: Dict[str, Any]):
@@ -432,6 +498,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
432498
self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns")
433499
self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)
434500

501+
self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
502+
JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"])
503+
if json_obj.get("hosting_instance_type_variants")
504+
else None
505+
)
506+
435507
if self.training_supported:
436508
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
437509
json_obj["training_ecr_specs"]
@@ -453,6 +525,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
453525
self.training_model_package_artifact_uris: Optional[Dict] = json_obj.get(
454526
"training_model_package_artifact_uris"
455527
)
528+
self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
529+
JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"])
530+
if json_obj.get("training_instance_type_variants")
531+
else None
532+
)
456533

457534
def to_json(self) -> Dict[str, Any]:
458535
"""Returns json representation of JumpStartModelSpecs object."""

src/sagemaker/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,3 +1435,17 @@ def instance_supports_kms(instance_type: str) -> bool:
14351435
ValueError: If the instance type is improperly formatted.
14361436
"""
14371437
return volume_size_supported(instance_type)
1438+
1439+
1440+
def get_instance_type_family(instance_type):
    """Extract the family portion of an instance type.

    The pattern accepts either the "ml.<family>.<size>" or the "ml_<family>" form.

    Args:
        instance_type: Instance type to inspect. Anything that is not a string
            (including None) is treated as having no family.

    Returns:
        str: The captured family, or an empty string when the input is not a string
        or does not match the pattern.
    """
    if not isinstance(instance_type, str):
        return ""

    found = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    return "" if found is None else found.group(1)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
import pytest
17+
18+
from sagemaker import image_uris
19+
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
20+
21+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
22+
23+
24+
@patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_variants_image_uri(
    patched_get_model_specs, patched_verify_model_region_and_return_specs
):
    """Check image uri retrieval for the "variant-model" spec across instance types."""

    # The verification patch delegates to the real function; model specs are served
    # by the test helper.
    patched_get_model_specs.side_effect = get_special_model_spec
    patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs

    def retrieve_variant_uri(instance_type, region="us-west-2", image_scope="inference"):
        """Call ``image_uris.retrieve`` for "variant-model" with the shared arguments."""
        return image_uris.retrieve(
            framework=None,
            region=region,
            image_scope=image_scope,
            model_id="variant-model",
            model_version="*",
            instance_type=instance_type,
        )

    expected_p2_uri = (
        "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
        "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04"
    )
    assert retrieve_variant_uri("ml.p2.xlarge") == expected_p2_uri

    assert retrieve_variant_uri("ml.c2.xlarge") == (
        "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah"
    )

    # Region "us-west-29": retrieval is expected to fail.
    with pytest.raises(ValueError):
        retrieve_variant_uri("ml.c2.xlarge", region="us-west-29")

    # Instance type "ml.c200000.xlarge": retrieval is expected to fail.
    with pytest.raises(ValueError):
        retrieve_variant_uri("ml.c200000.xlarge")

    expected_training_uri = (
        "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3"
    )
    assert retrieve_variant_uri("ml.g4dn.2xlarge", image_scope="training") == expected_training_uri

0 commit comments

Comments
 (0)