Skip to content

Commit ce47702

Browse files
evakravi authored and goelakash committed
feat: jumpstart model artifact instance type variants
1 parent 5d0f6b3 commit ce47702

File tree

9 files changed

+576
-13
lines changed

9 files changed

+576
-13
lines changed

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 47 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ def _retrieve_model_uri(
3434
model_id: str,
3535
model_version: str,
3636
model_scope: Optional[str] = None,
37+
instance_type: Optional[str] = None,
3738
region: Optional[str] = None,
3839
tolerate_vulnerable_model: bool = False,
3940
tolerate_deprecated_model: bool = False,
@@ -50,6 +51,8 @@ def _retrieve_model_uri(
5051
artifact S3 URI.
5152
model_scope (str): The model type, i.e. what it is used for.
5253
Valid values: "training" and "inference".
54+
instance_type (str): An instance type to optionally supply in order to get
55+
model artifacts specific for the instance type. (Default: None).
5356
region (str): Region for which to retrieve model S3 URI. (Default: None).
5457
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5558
specifications should be tolerated (exception not raised). If False, raises an
@@ -84,14 +87,55 @@ def _retrieve_model_uri(
8487
sagemaker_session=sagemaker_session,
8588
)
8689

90+
model_artifact_key: str
91+
8792
if model_scope == JumpStartScriptScope.INFERENCE:
93+
instance_specific_prepacked_hosting_artifact_key: Optional[str] = (
94+
model_specs.hosting_instance_type_variants.get_instance_specific_prepacked_artifact_key(
95+
instance_type=instance_type
96+
)
97+
if instance_type
98+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
99+
else None
100+
)
101+
102+
instance_specific_hosting_artifact_key: Optional[str] = (
103+
model_specs.hosting_instance_type_variants.get_instance_specific_artifact_key(
104+
instance_type=instance_type
105+
)
106+
if instance_type
107+
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
108+
else None
109+
)
110+
111+
default_prepacked_hosting_artifact_key: Optional[str] = getattr(
112+
model_specs, "hosting_prepacked_artifact_key"
113+
)
114+
115+
default_hosting_artifact_key: str = model_specs.hosting_artifact_key
116+
88117
model_artifact_key = (
89-
getattr(model_specs, "hosting_prepacked_artifact_key", None)
90-
or model_specs.hosting_artifact_key
118+
instance_specific_prepacked_hosting_artifact_key
119+
or instance_specific_hosting_artifact_key
120+
or default_prepacked_hosting_artifact_key
121+
or default_hosting_artifact_key
91122
)
92123

93124
elif model_scope == JumpStartScriptScope.TRAINING:
94-
model_artifact_key = model_specs.training_artifact_key
125+
instance_specific_training_artifact_key: Optional[str] = (
126+
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
127+
instance_type=instance_type
128+
)
129+
if instance_type
130+
and getattr(model_specs, "training_instance_type_variants", None) is not None
131+
else None
132+
)
133+
134+
default_training_artifact_key: str = model_specs.training_artifact_key
135+
136+
model_artifact_key = (
137+
instance_specific_training_artifact_key or default_training_artifact_key
138+
)
95139

96140
bucket = os.environ.get(
97141
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -473,6 +473,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
473473
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
474474
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
475475
sagemaker_session=kwargs.sagemaker_session,
476+
instance_type=kwargs.instance_type,
476477
)
477478

478479
if (

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
215215
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
216216
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
217217
sagemaker_session=kwargs.sagemaker_session,
218+
instance_type=kwargs.instance_type,
218219
)
219220

220221
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):

src/sagemaker/jumpstart/types.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -346,6 +346,62 @@ def to_json(self) -> Dict[str, Any]:
346346
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
347347
return json_obj
348348

349+
def get_instance_specific_prepacked_artifact_key(self, instance_type: str) -> Optional[str]:
    """Looks up the prepacked model artifact key for an instance type.

    Delegates to ``_get_instance_specific_property`` using the
    ``"prepacked_artifact_key"`` property name.

    Args:
        instance_type (str): Instance type to look up.

    Returns:
        Optional[str]: The value produced by the property lookup; None if the
            (model, instance type) pair has no specific prepacked artifact key.
    """

    prepacked_key: Optional[str] = self._get_instance_specific_property(
        property_name="prepacked_artifact_key",
        instance_type=instance_type,
    )
    return prepacked_key
359+
360+
def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str]:
    """Looks up the model artifact key for an instance type.

    Delegates to ``_get_instance_specific_property`` using the
    ``"artifact_key"`` property name.

    Args:
        instance_type (str): Instance type to look up.

    Returns:
        Optional[str]: The value produced by the property lookup; None if the
            (model, instance type) pair has no specific artifact key.
    """

    artifact_key: Optional[str] = self._get_instance_specific_property(
        property_name="artifact_key",
        instance_type=instance_type,
    )
    return artifact_key
370+
371+
def _get_instance_specific_property(
    self, instance_type: str, property_name: str
) -> Optional[str]:
    """Resolves a property from the instance type variants.

    The variant keyed by the exact instance type is consulted first. Only when
    that yields no truthy value is the variant keyed by the instance type
    family consulted, so an instance type value wins over a family value.

    Args:
        instance_type (str): Instance type to look up.
        property_name (str): Key to read from the variant's ``"properties"``
            mapping.

    Returns:
        Optional[str]: The resolved property value; None if there are no
            variants, the instance type has no family, or neither the
            instance type nor its family defines the property.
    """

    variants = self.variants
    if variants is None:
        return None

    # Exact instance type match takes precedence over the family entry.
    exact_entry = variants.get(instance_type, {})
    exact_value: Optional[str] = exact_entry.get("properties", {}).get(property_name, None)
    if exact_value:
        return exact_value

    family = get_instance_type_family(instance_type)
    if family in {"", None}:
        # No usable family name, so there is nothing left to fall back on.
        return None

    family_entry = variants.get(family, {})
    family_value: Optional[str] = family_entry.get("properties", {}).get(property_name, None)
    return family_value
404+
349405
def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
350406
"""Returns instance specific environment variables.
351407

src/sagemaker/model_uris.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ def retrieve(
3030
model_id: Optional[str] = None,
3131
model_version: Optional[str] = None,
3232
model_scope: Optional[str] = None,
33+
instance_type: Optional[str] = None,
3334
tolerate_vulnerable_model: bool = False,
3435
tolerate_deprecated_model: bool = False,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -44,6 +45,8 @@ def retrieve(
4445
the model artifact S3 URI.
4546
model_scope (str): The model type.
4647
Valid values: "training" and "inference".
48+
instance_type (str): An instance type to optionally supply in order to get
49+
model artifacts specific for the instance type. (Default: None).
4750
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
4851
specifications should be tolerated without raising an exception. If ``False``, raises an
4952
exception if the script used by this version of the model has dependencies with known
@@ -71,11 +74,12 @@ def retrieve(
7174
)
7275

7376
return artifacts._retrieve_model_uri(
74-
model_id,
75-
model_version, # type: ignore
76-
model_scope,
77-
region,
78-
tolerate_vulnerable_model,
79-
tolerate_deprecated_model,
77+
model_id=model_id,
78+
model_version=model_version, # type: ignore
79+
model_scope=model_scope,
80+
instance_type=instance_type,
81+
region=region,
82+
tolerate_vulnerable_model=tolerate_vulnerable_model,
83+
tolerate_deprecated_model=tolerate_deprecated_model,
8084
sagemaker_session=sagemaker_session,
8185
)

0 commit comments

Comments
 (0)