
Commit cf777e0

feature: combined model + script artifact (aws#3715)
* feat: combined model + script artifact
* chore: use ValueError
* chore: improve error msg for no combined artifact
* fix: jumpstart unit tests
* chore: always include inference script if available
1 parent 03d14f7 commit cf777e0
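
For orientation, a minimal sketch (assuming the public sagemaker.model_uris API; the model id and region mirror the unit test added at the end of this commit) of what the change means for callers: when a model's spec advertises a prepacked hosting artifact, model_uris.retrieve now returns that combined model + script tarball instead of the bare model artifact.

# Illustrative sketch only; the model id and region come from the test in this commit.
from sagemaker import model_uris

uri = model_uris.retrieve(
    region="us-west-2",
    model_scope="inference",
    model_id="huggingface-text2text-flan-t5-xxl-fp16",
    model_version="*",
)
# With this change, `uri` points at the prepacked (combined) artifact,
# e.g. ".../infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz".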

4 files changed: +257 -6 lines

src/sagemaker/jumpstart/artifacts.py

Lines changed: 9 additions & 5 deletions
@@ -173,10 +173,10 @@ def _retrieve_image_uri(
 def _retrieve_model_uri(
     model_id: str,
     model_version: str,
-    model_scope: Optional[str],
-    region: Optional[str],
-    tolerate_vulnerable_model: bool,
-    tolerate_deprecated_model: bool,
+    model_scope: Optional[str] = None,
+    region: Optional[str] = None,
+    tolerate_vulnerable_model: bool = False,
+    tolerate_deprecated_model: bool = False,
 ):
     """Retrieves the model artifact S3 URI for the model matching the given arguments.
@@ -219,7 +219,11 @@ def _retrieve_model_uri(
     )

     if model_scope == JumpStartScriptScope.INFERENCE:
-        model_artifact_key = model_specs.hosting_artifact_key
+        model_artifact_key = (
+            getattr(model_specs, "hosting_prepacked_artifact_key", None)
+            or model_specs.hosting_artifact_key
+        )
+
     elif model_scope == JumpStartScriptScope.TRAINING:
         model_artifact_key = model_specs.training_artifact_key
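
The heart of the change is the artifact-key selection for the inference scope: prefer the prepacked (combined) artifact when the spec carries one, otherwise fall back to the plain hosting artifact. A standalone sketch of that pattern, using a hypothetical stand-in object rather than the real JumpStartModelSpecs:

# Hypothetical stand-in for JumpStartModelSpecs; illustration only.
class _FakeSpecs:
    hosting_artifact_key = "huggingface-infer/infer-model.tar.gz"
    hosting_prepacked_artifact_key = "huggingface-infer/prepack/v1.0.0/infer-prepack-model.tar.gz"

specs = _FakeSpecs()

# getattr keeps older cached specs (which lack the new field) working, and
# `or` falls back when the field is present but None.
model_artifact_key = (
    getattr(specs, "hosting_prepacked_artifact_key", None) or specs.hosting_artifact_key
)
assert model_artifact_key.startswith("huggingface-infer/prepack/")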

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
@@ -294,6 +294,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
         "deprecated",
         "metrics",
         "training_prepacked_script_key",
+        "hosting_prepacked_artifact_key",
     ]

     def __init__(self, spec: Dict[str, Any]):
@@ -334,6 +335,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
         self.training_prepacked_script_key: Optional[str] = json_obj.get(
             "training_prepacked_script_key", None
         )
+        self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
+            "hosting_prepacked_artifact_key", None
+        )

         if self.training_supported:
             self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
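
The new spec field is optional and parsed the same way as training_prepacked_script_key: spec documents that lack the key simply yield None. A miniature sketch of that pattern (a hypothetical _MiniSpecs class, not the real data holder):

from typing import Any, Dict, Optional


# Hypothetical miniature of the from_json pattern above; illustration only.
class _MiniSpecs:
    __slots__ = ["hosting_artifact_key", "hosting_prepacked_artifact_key"]

    def __init__(self, json_obj: Dict[str, Any]) -> None:
        self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key", None)
        # Older spec documents omit this key, so it parses to None instead of raising.
        self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get(
            "hosting_prepacked_artifact_key", None
        )


old_style = _MiniSpecs({"hosting_artifact_key": "pytorch-infer/infer-model.tar.gz"})
assert old_style.hosting_prepacked_artifact_key is None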

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 206 additions & 1 deletion
@@ -14,6 +14,210 @@


 SPECIAL_MODEL_SPECS_DICT = {
+    "no-supported-instance-types-model": {
+        "model_id": "pytorch-ic-mobilenet-v2",
+        "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",
+        "version": "1.0.0",
+        "min_sdk_version": "2.49.0",
+        "training_supported": True,
+        "incremental_training_supported": True,
+        "hosting_ecr_specs": {
+            "framework": "pytorch",
+            "framework_version": "1.5.0",
+            "py_version": "py3",
+        },
+        "training_ecr_specs": {
+            "framework": "pytorch",
+            "framework_version": "1.5.0",
+            "py_version": "py3",
+        },
+        "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
+        "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
+        "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
+        "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
+        "hyperparameters": [
+            {
+                "name": "epochs",
+                "type": "int",
+                "default": 3,
+                "min": 1,
+                "max": 1000,
+                "scope": "algorithm",
+            },
+            {
+                "name": "adam-learning-rate",
+                "type": "float",
+                "default": 0.05,
+                "min": 1e-08,
+                "max": 1,
+                "scope": "algorithm",
+            },
+            {
+                "name": "batch-size",
+                "type": "int",
+                "default": 4,
+                "min": 1,
+                "max": 1024,
+                "scope": "algorithm",
+            },
+            {
+                "name": "sagemaker_submit_directory",
+                "type": "text",
+                "default": "/opt/ml/input/data/code/sourcedir.tar.gz",
+                "scope": "container",
+            },
+            {
+                "name": "sagemaker_program",
+                "type": "text",
+                "default": "transfer_learning.py",
+                "scope": "container",
+            },
+            {
+                "name": "sagemaker_container_log_level",
+                "type": "text",
+                "default": "20",
+                "scope": "container",
+            },
+        ],
+        "inference_environment_variables": [
+            {
+                "name": "SAGEMAKER_PROGRAM",
+                "type": "text",
+                "default": "inference.py",
+                "scope": "container",
+            },
+            {
+                "name": "SAGEMAKER_SUBMIT_DIRECTORY",
+                "type": "text",
+                "default": "/opt/ml/model/code",
+                "scope": "container",
+            },
+            {
+                "name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
+                "type": "text",
+                "default": "20",
+                "scope": "container",
+            },
+            {
+                "name": "MODEL_CACHE_ROOT",
+                "type": "text",
+                "default": "/opt/ml/model",
+                "scope": "container",
+            },
+            {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
+            {
+                "name": "SAGEMAKER_MODEL_SERVER_WORKERS",
+                "type": "text",
+                "default": "1",
+                "scope": "container",
+            },
+            {
+                "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
+                "type": "text",
+                "default": "3600",
+                "scope": "container",
+            },
+        ],
+        "default_inference_instance_type": "",
+        "supported_inference_instance_types": None,
+        "default_training_instance_type": None,
+        "supported_training_instance_types": [],
+        "inference_vulnerable": False,
+        "inference_dependencies": [],
+        "inference_vulnerabilities": [],
+        "training_vulnerable": False,
+        "training_dependencies": [],
+        "training_vulnerabilities": [],
+        "deprecated": False,
+        "metrics": [],
+    },
+    "huggingface-text2text-flan-t5-xxl-fp16": {
+        "model_id": "huggingface-text2text-flan-t5-xxl-fp16",
+        "url": "https://huggingface.co/google/flan-t5-xxl",
+        "version": "1.0.0",
+        "min_sdk_version": "2.130.0",
+        "training_supported": False,
+        "incremental_training_supported": False,
+        "hosting_ecr_specs": {
+            "framework": "pytorch",
+            "framework_version": "1.12.0",
+            "py_version": "py38",
+            "huggingface_transformers_version": "4.17.0",
+        },
+        "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
+        "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz",
+        "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-"
+        "text2text-flan-t5-xxl-fp16.tar.gz",
+        "hosting_prepacked_artifact_version": "1.0.0",
+        "inference_vulnerable": False,
+        "inference_dependencies": [
+            "accelerate==0.16.0",
+            "bitsandbytes==0.37.0",
+            "filelock==3.9.0",
+            "huggingface-hub==0.12.0",
+            "regex==2022.7.9",
+            "tokenizers==0.13.2",
+            "transformers==4.26.0",
+        ],
+        "inference_vulnerabilities": [],
+        "training_vulnerable": False,
+        "training_dependencies": [],
+        "training_vulnerabilities": [],
+        "deprecated": False,
+        "inference_environment_variables": [
+            {
+                "name": "SAGEMAKER_PROGRAM",
+                "type": "text",
+                "default": "inference.py",
+                "scope": "container",
+            },
+            {
+                "name": "SAGEMAKER_SUBMIT_DIRECTORY",
+                "type": "text",
+                "default": "/opt/ml/model/code",
+                "scope": "container",
+            },
+            {
+                "name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
+                "type": "text",
+                "default": "20",
+                "scope": "container",
+            },
+            {
+                "name": "MODEL_CACHE_ROOT",
+                "type": "text",
+                "default": "/opt/ml/model",
+                "scope": "container",
+            },
+            {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
+            {
+                "name": "SAGEMAKER_MODEL_SERVER_WORKERS",
+                "type": "text",
+                "default": "1",
+                "scope": "container",
+            },
+            {
+                "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
+                "type": "text",
+                "default": "3600",
+                "scope": "container",
+            },
+        ],
+        "inference_vulnerable": False,
+        "training_vulnerable": False,
+        "deprecated": False,
+        "default_training_instance_type": None,
+        "supported_training_instance_types": [],
+        "metrics": [],
+        "default_inference_instance_type": "ml.g5.12xlarge",
+        "supported_inference_instance_types": [
+            "ml.g5.12xlarge",
+            "ml.g5.24xlarge",
+            "ml.p3.8xlarge",
+            "ml.p3.16xlarge",
+            "ml.g4dn.12xlarge",
+        ],
+    },
     "mock-model-training-prepacked-script-key": {
         "model_id": "sklearn-classification-linear",
         "url": "https://scikit-learn.org/stable/",
@@ -134,7 +338,7 @@
                 "scope": "container",
             },
         ],
-    }
+    },
 }

 PROTOTYPICAL_MODEL_SPECS_DICT = {
@@ -1219,6 +1423,7 @@
     "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
     "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
     "training_prepacked_script_key": None,
+    "hosting_prepacked_artifact_key": None,
     "hyperparameters": [
         {
             "name": "epochs",
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from mock.mock import patch
+
+from sagemaker import model_uris
+
+from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
+
+
+@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+def test_jumpstart_combined_artifacts(patched_get_model_specs):
+
+    patched_get_model_specs.side_effect = get_special_model_spec
+
+    model_id_combined_model_artifact = "huggingface-text2text-flan-t5-xxl-fp16"
+
+    uri = model_uris.retrieve(
+        region="us-west-2",
+        model_scope="inference",
+        model_id=model_id_combined_model_artifact,
+        model_version="*",
+    )
+    assert (
+        uri == "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/"
+        "prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz"
+    )
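
As a downstream illustration (not part of this commit; the role ARN is a placeholder, and the image_uris.retrieve call assumes the SDK's JumpStart support for resolving container images by model id): because the combined tarball already bundles the inference script, the retrieved URI can be handed to a Model as model_data without a separate source_dir or entry_point.

# Hypothetical usage sketch; not part of this commit.
from sagemaker import image_uris, model_uris
from sagemaker.model import Model

model_id, model_version = "huggingface-text2text-flan-t5-xxl-fp16", "*"

image_uri = image_uris.retrieve(
    framework=None,
    region="us-west-2",
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type="ml.g5.12xlarge",
)
model_data = model_uris.retrieve(
    region="us-west-2",
    model_scope="inference",
    model_id=model_id,
    model_version=model_version,
)

# The prepacked artifact already bundles the inference script, so no separate
# source_dir or entry_point is supplied here.
model = Model(
    image_uri=image_uri,
    model_data=model_data,
    role="arn:aws:iam::123456789012:role/ExampleSageMakerRole",  # placeholder
)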
