Skip to content

Commit 8bb9cba

Browse files
committed
feat: combined model + script artifact
1 parent ca22ac6 commit 8bb9cba

File tree

5 files changed

+177
-6
lines changed

5 files changed

+177
-6
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,11 @@ def _retrieve_image_uri(
173173
def _retrieve_model_uri(
174174
model_id: str,
175175
model_version: str,
176-
model_scope: Optional[str],
177-
region: Optional[str],
178-
tolerate_vulnerable_model: bool,
179-
tolerate_deprecated_model: bool,
176+
model_scope: Optional[str] = None,
177+
region: Optional[str] = None,
178+
tolerate_vulnerable_model: bool = False,
179+
tolerate_deprecated_model: bool = False,
180+
include_script: bool = False,
180181
):
181182
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
182183
@@ -197,6 +198,8 @@ def _retrieve_model_uri(
197198
tolerate_deprecated_model (bool): True if deprecated versions of model
198199
specifications should be tolerated (exception not raised). If False, raises
199200
an exception if the version of the model is deprecated.
201+
include_script (bool): True if script artifact should be packaged alongside model
202+
tarball. (Default: False).
200203
Returns:
201204
str: the model artifact S3 URI for the corresponding model.
202205
@@ -205,6 +208,8 @@ def _retrieve_model_uri(
205208
VulnerableJumpStartModelError: If any of the dependencies required by the script have
206209
known security vulnerabilities.
207210
DeprecatedJumpStartModelError: If the version of the model is deprecated.
211+
NotImplementedError: If the combination of arguments doesn't support combined model
212+
and script artifact.
208213
"""
209214
if region is None:
210215
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -218,10 +223,24 @@ def _retrieve_model_uri(
218223
tolerate_deprecated_model=tolerate_deprecated_model,
219224
)
220225

226+
error_msg_no_combined_artifact = (
227+
"No combined script + model tarball available "
228+
f"for {model_id} with version {model_version} for {model_scope}."
229+
)
230+
221231
if model_scope == JumpStartScriptScope.INFERENCE:
222-
model_artifact_key = model_specs.hosting_artifact_key
232+
if not include_script:
233+
model_artifact_key = model_specs.hosting_artifact_key
234+
else:
235+
model_artifact_key = getattr(model_specs, "hosting_prepacked_artifact_key", None)
236+
if model_artifact_key is None:
237+
raise NotImplementedError(error_msg_no_combined_artifact)
238+
223239
elif model_scope == JumpStartScriptScope.TRAINING:
224-
model_artifact_key = model_specs.training_artifact_key
240+
if not include_script:
241+
model_artifact_key = model_specs.training_artifact_key
242+
else:
243+
raise NotImplementedError(error_msg_no_combined_artifact)
225244

226245
bucket = os.environ.get(
227246
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
293293
"training_vulnerabilities",
294294
"deprecated",
295295
"metrics",
296+
"hosting_prepacked_artifact_key",
296297
]
297298

298299
def __init__(self, spec: Dict[str, Any]):
@@ -330,6 +331,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
330331
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
331332
self.deprecated: bool = bool(json_obj["deprecated"])
332333
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
334+
self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
335+
"hosting_prepacked_artifact_key", None
336+
)
333337

334338
if self.training_supported:
335339
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

src/sagemaker/model_uris.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def retrieve(
3030
model_scope: Optional[str] = None,
3131
tolerate_vulnerable_model: bool = False,
3232
tolerate_deprecated_model: bool = False,
33+
include_script: bool = False,
3334
) -> str:
3435
"""Retrieves the model artifact Amazon S3 URI for the model matching the given arguments.
3536
@@ -48,6 +49,8 @@ def retrieve(
4849
tolerate_deprecated_model (bool): ``True`` if deprecated versions of model
4950
specifications should be tolerated without raising an exception. If ``False``, raises
5051
an exception if the version of the model is deprecated. (Default: False).
52+
include_script (bool): True if script artifact should be packaged alongside model
53+
tarball. (Default: False).
5154
Returns:
5255
str: The model artifact S3 URI for the corresponding model.
5356
@@ -57,6 +60,8 @@ def retrieve(
5760
VulnerableJumpStartModelError: If any of the dependencies required by the script have
5861
known security vulnerabilities.
5962
DeprecatedJumpStartModelError: If the version of the model is deprecated.
63+
NotImplementedError: If the combination of arguments doesn't support combined model
64+
and script artifact.
6065
"""
6166
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
6267
raise ValueError("Must specify `model_id` and `model_version` when retrieving model URIs.")
@@ -68,4 +73,5 @@ def retrieve(
6873
region,
6974
tolerate_vulnerable_model,
7075
tolerate_deprecated_model,
76+
include_script,
7177
)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,88 @@
10701070
},
10711071
],
10721072
},
1073+
"huggingface-text2text-flan-t5-xxl-fp16": {
1074+
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
1075+
"url": "https://huggingface.co/google/flan-t5-xxl",
1076+
"version": "1.0.0",
1077+
"min_sdk_version": "2.130.0",
1078+
"training_supported": False,
1079+
"incremental_training_supported": False,
1080+
"hosting_ecr_specs": {
1081+
"framework": "pytorch",
1082+
"framework_version": "1.12.0",
1083+
"py_version": "py38",
1084+
"huggingface_transformers_version": "4.17.0",
1085+
},
1086+
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
1087+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz",
1088+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-"
1089+
"text2text-flan-t5-xxl-fp16.tar.gz",
1090+
"hosting_prepacked_artifact_version": "1.0.0",
1091+
"inference_vulnerable": False,
1092+
"inference_dependencies": [
1093+
"accelerate==0.16.0",
1094+
"bitsandbytes==0.37.0",
1095+
"filelock==3.9.0",
1096+
"huggingface-hub==0.12.0",
1097+
"regex==2022.7.9",
1098+
"tokenizers==0.13.2",
1099+
"transformers==4.26.0",
1100+
],
1101+
"inference_vulnerabilities": [],
1102+
"training_vulnerable": False,
1103+
"training_dependencies": [],
1104+
"training_vulnerabilities": [],
1105+
"deprecated": False,
1106+
"inference_environment_variables": [
1107+
{
1108+
"name": "SAGEMAKER_PROGRAM",
1109+
"type": "text",
1110+
"default": "inference.py",
1111+
"scope": "container",
1112+
},
1113+
{
1114+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1115+
"type": "text",
1116+
"default": "/opt/ml/model/code",
1117+
"scope": "container",
1118+
},
1119+
{
1120+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1121+
"type": "text",
1122+
"default": "20",
1123+
"scope": "container",
1124+
},
1125+
{
1126+
"name": "MODEL_CACHE_ROOT",
1127+
"type": "text",
1128+
"default": "/opt/ml/model",
1129+
"scope": "container",
1130+
},
1131+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1132+
{
1133+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1134+
"type": "text",
1135+
"default": "1",
1136+
"scope": "container",
1137+
},
1138+
{
1139+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1140+
"type": "text",
1141+
"default": "3600",
1142+
"scope": "container",
1143+
},
1144+
],
1145+
"metrics": [],
1146+
"default_inference_instance_type": "ml.g5.12xlarge",
1147+
"supported_inference_instance_types": [
1148+
"ml.g5.12xlarge",
1149+
"ml.g5.24xlarge",
1150+
"ml.p3.8xlarge",
1151+
"ml.p3.16xlarge",
1152+
"ml.g4dn.12xlarge",
1153+
],
1154+
},
10731155
}
10741156

10751157
BASE_SPEC = {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from mock.mock import patch
16+
17+
from sagemaker import model_uris
18+
import pytest
19+
20+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
21+
22+
23+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
24+
def test_jumpstart_combined_artifacts(patched_get_model_specs):
25+
26+
patched_get_model_specs.side_effect = get_prototype_model_spec
27+
28+
model_id_combined_model_artifact = "huggingface-text2text-flan-t5-xxl-fp16"
29+
30+
uri = model_uris.retrieve(
31+
region="us-west-2",
32+
model_scope="inference",
33+
model_id=model_id_combined_model_artifact,
34+
model_version="*",
35+
include_script=True,
36+
)
37+
assert (
38+
uri == "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/"
39+
"prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz"
40+
)
41+
42+
with pytest.raises(NotImplementedError):
43+
model_uris.retrieve(
44+
region="us-west-2",
45+
model_scope="transfer_learning",
46+
model_id=model_id_combined_model_artifact,
47+
model_version="*",
48+
include_script=True,
49+
)
50+
51+
model_id_combined_model_artifact_unsupported = "xgboost-classification-model"
52+
53+
with pytest.raises(NotImplementedError):
54+
model_uris.retrieve(
55+
region="us-west-2",
56+
model_scope="inference",
57+
model_id=model_id_combined_model_artifact_unsupported,
58+
model_version="*",
59+
include_script=True,
60+
)

0 commit comments

Comments
 (0)