Skip to content

Commit dba2dbb

Browse files
authored
feat: jumpstart model artifact instance type variants (#4172)
1 parent accd220 commit dba2dbb

File tree

10 files changed

+780
-12
lines changed

10 files changed

+780
-12
lines changed

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,68 @@
2828
verify_model_region_and_return_specs,
2929
)
3030
from sagemaker.session import Session
31+
from sagemaker.jumpstart.types import JumpStartModelSpecs
32+
33+
34+
def _retrieve_hosting_prepacked_artifact_key(
35+
model_specs: JumpStartModelSpecs, instance_type: str
36+
) -> str:
37+
"""Returns instance specific hosting prepacked artifact key or default one as fallback."""
38+
instance_specific_prepacked_hosting_artifact_key: Optional[str] = (
39+
model_specs.hosting_instance_type_variants.get_instance_specific_prepacked_artifact_key(
40+
instance_type=instance_type
41+
)
42+
if instance_type
43+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
44+
else None
45+
)
46+
47+
default_prepacked_hosting_artifact_key: Optional[str] = getattr(
48+
model_specs, "hosting_prepacked_artifact_key"
49+
)
50+
51+
return (
52+
instance_specific_prepacked_hosting_artifact_key or default_prepacked_hosting_artifact_key
53+
)
54+
55+
56+
def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
57+
"""Returns instance specific hosting artifact key or default one as fallback."""
58+
instance_specific_hosting_artifact_key: Optional[str] = (
59+
model_specs.hosting_instance_type_variants.get_instance_specific_artifact_key(
60+
instance_type=instance_type
61+
)
62+
if instance_type
63+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
64+
else None
65+
)
66+
67+
default_hosting_artifact_key: str = model_specs.hosting_artifact_key
68+
69+
return instance_specific_hosting_artifact_key or default_hosting_artifact_key
70+
71+
72+
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
73+
"""Returns instance specific training artifact key or default one as fallback."""
74+
instance_specific_training_artifact_key: Optional[str] = (
75+
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
76+
instance_type=instance_type
77+
)
78+
if instance_type
79+
and getattr(model_specs, "training_instance_type_variants", None) is not None
80+
else None
81+
)
82+
83+
default_training_artifact_key: str = model_specs.training_artifact_key
84+
85+
return instance_specific_training_artifact_key or default_training_artifact_key
3186

3287

3388
def _retrieve_model_uri(
3489
model_id: str,
3590
model_version: str,
3691
model_scope: Optional[str] = None,
92+
instance_type: Optional[str] = None,
3793
region: Optional[str] = None,
3894
tolerate_vulnerable_model: bool = False,
3995
tolerate_deprecated_model: bool = False,
@@ -50,6 +106,7 @@ def _retrieve_model_uri(
50106
artifact S3 URI.
51107
model_scope (str): The model type, i.e. what it is used for.
52108
Valid values: "training" and "inference".
109+
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
53110
region (str): Region for which to retrieve model S3 URI. (Default: None).
54111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
55112
specifications should be tolerated (exception not raised). If False, raises an
@@ -84,14 +141,21 @@ def _retrieve_model_uri(
84141
sagemaker_session=sagemaker_session,
85142
)
86143

144+
model_artifact_key: str
145+
87146
if model_scope == JumpStartScriptScope.INFERENCE:
147+
148+
is_prepacked = not model_specs.use_inference_script_uri()
149+
88150
model_artifact_key = (
89-
getattr(model_specs, "hosting_prepacked_artifact_key", None)
90-
or model_specs.hosting_artifact_key
151+
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
152+
if is_prepacked
153+
else _retrieve_hosting_artifact_key(model_specs, instance_type)
91154
)
92155

93156
elif model_scope == JumpStartScriptScope.TRAINING:
94-
model_artifact_key = model_specs.training_artifact_key
157+
158+
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
95159

96160
bucket = os.environ.get(
97161
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
475475
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
476476
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
477477
sagemaker_session=kwargs.sagemaker_session,
478+
instance_type=kwargs.instance_type,
478479
)
479480

480481
if (

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
215215
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
216216
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
217217
sagemaker_session=kwargs.sagemaker_session,
218+
instance_type=kwargs.instance_type,
218219
)
219220

220221
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):

src/sagemaker/jumpstart/types.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,62 @@ def to_json(self) -> Dict[str, Any]:
403403
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
404404
return json_obj
405405

406+
def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
407+
"""Returns instance specific model artifact key.
408+
409+
Returns None if a model, instance type tuple does not have specific
410+
artifact key.
411+
"""
412+
413+
return self._get_instance_specific_property(
414+
instance_type=instance_type, property_name="prepacked_artifact_key"
415+
)
416+
417+
def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
418+
"""Returns instance specific model artifact key.
419+
420+
Returns None if a model, instance type tuple does not have specific
421+
artifact key.
422+
"""
423+
424+
return self._get_instance_specific_property(
425+
instance_type=instance_type, property_name="artifact_key"
426+
)
427+
428+
def _get_instance_specific_property(
429+
self, instance_type: str, property_name: str
430+
) -> Optional[str]:
431+
"""Returns instance specific property.
432+
433+
If a value exists for both the instance family and instance type,
434+
the instance type value is chosen.
435+
436+
Returns None if a (model, instance type, property name) tuple does not have
437+
specific prepacked artifact key.
438+
"""
439+
440+
if self.variants is None:
441+
return None
442+
443+
instance_specific_property: Optional[str] = (
444+
self.variants.get(instance_type, {}).get("properties", {}).get(property_name, None)
445+
)
446+
447+
if instance_specific_property:
448+
return instance_specific_property
449+
450+
instance_type_family = get_instance_type_family(instance_type)
451+
452+
instance_family_property: Optional[str] = (
453+
self.variants.get(instance_type_family, {})
454+
.get("properties", {})
455+
.get(property_name, None)
456+
if instance_type_family not in {"", None}
457+
else None
458+
)
459+
460+
return instance_family_property
461+
406462
def get_instance_specific_hyperparameters(
407463
self, instance_type: str
408464
) -> List[JumpStartHyperparameter]:

src/sagemaker/model_uris.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve(
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
3232
model_scope: Optional[str] = None,
33+
instance_type: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -44,6 +45,7 @@ def retrieve(
4445
the model artifact S3 URI.
4546
model_scope (str): The model type.
4647
Valid values: "training" and "inference".
48+
instance_type (str): The ML compute instance type for the specified scope. (Default: None).
4749
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
4850
specifications should be tolerated without raising an exception. If ``False``, raises an
4951
exception if the script used by this version of the model has dependencies with known
@@ -71,11 +73,12 @@ def retrieve(
7173
)
7274

7375
return artifacts._retrieve_model_uri(
74-
model_id,
75-
model_version, # type: ignore
76-
model_scope,
77-
region,
78-
tolerate_vulnerable_model,
79-
tolerate_deprecated_model,
76+
model_id=model_id,
77+
model_version=model_version, # type: ignore
78+
model_scope=model_scope,
79+
instance_type=instance_type,
80+
region=region,
81+
tolerate_vulnerable_model=tolerate_vulnerable_model,
82+
tolerate_deprecated_model=tolerate_deprecated_model,
8083
sagemaker_session=sagemaker_session,
8184
)

0 commit comments

Comments
 (0)