Skip to content

Commit 6ca6627

Browse files
committed
feat: instance type variants for environment variables
1 parent 3f42c15 commit 6ca6627

File tree

7 files changed

+404
-50
lines changed

7 files changed

+404
-50
lines changed

src/sagemaker/environment_variables.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart import utils as jumpstart_utils
2121
from sagemaker.jumpstart import artifacts
2222
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
23+
from sagemaker.jumpstart.enums import JumpStartScriptScope
2324
from sagemaker.session import Session
2425

2526
logger = logging.getLogger(__name__)
@@ -33,6 +34,8 @@ def retrieve_default(
3334
tolerate_deprecated_model: bool = False,
3435
include_aws_sdk_env_vars: bool = True,
3536
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
37+
instance_type: Optional[str] = None,
38+
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
3639
) -> Dict[str, str]:
3740
"""Retrieves the default container environment variables for the model matching the arguments.
3841
@@ -58,6 +61,9 @@ def retrieve_default(
5861
object, used for SageMaker interactions. If not
5962
specified, one is created using the default AWS configuration
6063
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
instance_type (str): An instance type to optionally supply in order to get environment variables
65+
specific for the instance type.
66+
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables.
6167
Returns:
6268
dict: The variables to use for the model.
6369
@@ -78,4 +84,6 @@ def retrieve_default(
7884
tolerate_deprecated_model,
7985
include_aws_sdk_env_vars,
8086
sagemaker_session=sagemaker_session,
87+
instance_type=instance_type,
88+
script=script,
8189
)

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def _retrieve_default_environment_variables(
3434
tolerate_deprecated_model: bool = False,
3535
include_aws_sdk_env_vars: bool = True,
3636
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
37+
instance_type: Optional[str] = None,
38+
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
3739
) -> Dict[str, str]:
3840
"""Retrieves the inference environment variables for the model matching the given arguments.
3941
@@ -59,6 +61,9 @@ def _retrieve_default_environment_variables(
5961
object, used for SageMaker interactions. If not
6062
specified, one is created using the default AWS configuration
6163
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
instance_type (str): An instance type to optionally supply in order to get environment variables
65+
specific for the instance type.
66+
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables.
6267
Returns:
6368
dict: the inference environment variables to use for the model.
6469
"""
@@ -69,17 +74,37 @@ def _retrieve_default_environment_variables(
6974
model_specs = verify_model_region_and_return_specs(
7075
model_id=model_id,
7176
version=model_version,
72-
scope=JumpStartScriptScope.INFERENCE,
77+
scope=script,
7378
region=region,
7479
tolerate_vulnerable_model=tolerate_vulnerable_model,
7580
tolerate_deprecated_model=tolerate_deprecated_model,
7681
sagemaker_session=sagemaker_session,
7782
)
7883

7984
default_environment_variables: Dict[str, str] = {}
80-
for environment_variable in model_specs.inference_environment_variables:
81-
if include_aws_sdk_env_vars or environment_variable.required_for_model_class:
82-
default_environment_variables[environment_variable.name] = str(
83-
environment_variable.default
85+
if script == JumpStartScriptScope.INFERENCE:
86+
for environment_variable in model_specs.inference_environment_variables:
87+
if include_aws_sdk_env_vars or environment_variable.required_for_model_class:
88+
default_environment_variables[environment_variable.name] = str(
89+
environment_variable.default
90+
)
91+
92+
if instance_type:
93+
if script == JumpStartScriptScope.INFERENCE and getattr(
94+
model_specs, "hosting_instance_type_variants", None
95+
):
96+
default_environment_variables.update(
97+
model_specs.hosting_instance_type_variants.get_instance_specific_environment_variables(
98+
instance_type
99+
)
100+
)
101+
elif script == JumpStartScriptScope.TRAINING and getattr(
102+
model_specs, "training_instance_type_variants", None
103+
):
104+
default_environment_variables.update(
105+
model_specs.training_instance_type_variants.get_instance_specific_environment_variables(
106+
instance_type
107+
)
84108
)
109+
85110
return default_environment_variables

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
286286
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
287287
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
288288
sagemaker_session=kwargs.sagemaker_session,
289+
script=JumpStartScriptScope.INFERENCE,
290+
instance_type=kwargs.instance_type,
289291
)
290292

291293
for key, value in extra_env_vars.items():

src/sagemaker/jumpstart/types.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,35 @@ def to_json(self) -> Dict[str, Any]:
346346
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
347347
return json_obj
348348

349+
def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
350+
"""Returns instance specific environment variables.
351+
352+
Not all models and images have instance specific environment variables.
353+
"""
354+
355+
if self.variants is None:
356+
return {}
357+
358+
instance_specific_environment_variables: dict = (
359+
self.variants.get(instance_type, {})
360+
.get("properties", {})
361+
.get("environment_variables", {})
362+
)
363+
364+
instance_type_family = get_instance_type_family(instance_type)
365+
366+
instance_family_environment_variables: dict = (
367+
self.variants.get(instance_type_family, {})
368+
.get("properties", {})
369+
.get("environment_variables", {})
370+
if instance_type_family not in {"", None}
371+
else {}
372+
)
373+
374+
instance_family_environment_variables.update(instance_specific_environment_variables)
375+
376+
return instance_family_environment_variables
377+
349378
def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
350379
"""Returns image uri from instance type and region.
351380

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker import environment_variables
2121

22-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
22+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2323

2424
mock_client = boto3.client("s3")
2525
mock_session = Mock(s3_client=mock_client)
@@ -175,3 +175,76 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
175175
model_id=model_id,
176176
include_aws_sdk_env_vars=False,
177177
)
178+
179+
180+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
181+
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
182+
183+
patched_get_model_specs.side_effect = get_special_model_spec
184+
185+
model_id = "env-var-variant-model"
186+
region = "us-west-2"
187+
188+
# assert that we can override default environment variables
189+
vars = environment_variables.retrieve_default(
190+
region=region,
191+
model_id=model_id,
192+
model_version="*",
193+
include_aws_sdk_env_vars=False,
194+
sagemaker_session=mock_session,
195+
instance_type="ml.g5.48xlarge",
196+
)
197+
assert vars == {
198+
"ENDPOINT_SERVER_TIMEOUT": "3600",
199+
"HF_MODEL_ID": "/opt/ml/model",
200+
"MAX_INPUT_LENGTH": "1024",
201+
"MAX_TOTAL_TOKENS": "2048",
202+
"MODEL_CACHE_ROOT": "/opt/ml/model",
203+
"SAGEMAKER_ENV": "1",
204+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
205+
"SAGEMAKER_PROGRAM": "inference.py",
206+
"SM_NUM_GPUS": "80",
207+
}
208+
209+
# assert that we can add environment variables
210+
vars = environment_variables.retrieve_default(
211+
region=region,
212+
model_id=model_id,
213+
model_version="*",
214+
include_aws_sdk_env_vars=False,
215+
sagemaker_session=mock_session,
216+
instance_type="ml.p4d.24xlarge",
217+
)
218+
assert vars == {
219+
"ENDPOINT_SERVER_TIMEOUT": "3600",
220+
"HF_MODEL_ID": "/opt/ml/model",
221+
"MAX_INPUT_LENGTH": "1024",
222+
"MAX_TOTAL_TOKENS": "2048",
223+
"MODEL_CACHE_ROOT": "/opt/ml/model",
224+
"SAGEMAKER_ENV": "1",
225+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
226+
"SAGEMAKER_PROGRAM": "inference.py",
227+
"SM_NUM_GPUS": "8",
228+
"YODEL": "NACEREMA",
229+
}
230+
231+
# assert that we can return default env variables for unrecognized instance
232+
vars = environment_variables.retrieve_default(
233+
region=region,
234+
model_id=model_id,
235+
model_version="*",
236+
include_aws_sdk_env_vars=False,
237+
sagemaker_session=mock_session,
238+
instance_type="ml.p002.xlarge",
239+
)
240+
assert vars == {
241+
"ENDPOINT_SERVER_TIMEOUT": "3600",
242+
"HF_MODEL_ID": "/opt/ml/model",
243+
"MAX_INPUT_LENGTH": "1024",
244+
"MAX_TOTAL_TOKENS": "2048",
245+
"MODEL_CACHE_ROOT": "/opt/ml/model",
246+
"SAGEMAKER_ENV": "1",
247+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
248+
"SAGEMAKER_PROGRAM": "inference.py",
249+
"SM_NUM_GPUS": "8",
250+
}

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,166 @@
1414

1515

1616
SPECIAL_MODEL_SPECS_DICT = {
17+
"env-var-variant-model": {
18+
"model_id": "huggingface-llm-falcon-180b-bf16",
19+
"url": "https://huggingface.co/tiiuae/falcon-180B",
20+
"version": "1.0.0",
21+
"min_sdk_version": "2.175.0",
22+
"training_supported": False,
23+
"incremental_training_supported": False,
24+
"hosting_ecr_specs": {
25+
"framework": "huggingface-llm",
26+
"framework_version": "0.9.3",
27+
"py_version": "py39",
28+
"huggingface_transformers_version": "4.29.2",
29+
},
30+
"hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz",
31+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz",
32+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack"
33+
"-huggingface-llm-falcon-180b-bf16.tar.gz",
34+
"hosting_prepacked_artifact_version": "1.0.1",
35+
"hosting_use_script_uri": False,
36+
"inference_vulnerable": False,
37+
"inference_dependencies": [],
38+
"inference_vulnerabilities": [],
39+
"training_vulnerable": False,
40+
"training_dependencies": [],
41+
"training_vulnerabilities": [],
42+
"deprecated": False,
43+
"inference_environment_variables": [
44+
{
45+
"name": "SAGEMAKER_PROGRAM",
46+
"type": "text",
47+
"default": "inference.py",
48+
"scope": "container",
49+
"required_for_model_class": True,
50+
},
51+
{
52+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
53+
"type": "text",
54+
"default": "/opt/ml/model/code",
55+
"scope": "container",
56+
"required_for_model_class": False,
57+
},
58+
{
59+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
60+
"type": "text",
61+
"default": "20",
62+
"scope": "container",
63+
"required_for_model_class": False,
64+
},
65+
{
66+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
67+
"type": "text",
68+
"default": "3600",
69+
"scope": "container",
70+
"required_for_model_class": False,
71+
},
72+
{
73+
"name": "ENDPOINT_SERVER_TIMEOUT",
74+
"type": "int",
75+
"default": 3600,
76+
"scope": "container",
77+
"required_for_model_class": True,
78+
},
79+
{
80+
"name": "MODEL_CACHE_ROOT",
81+
"type": "text",
82+
"default": "/opt/ml/model",
83+
"scope": "container",
84+
"required_for_model_class": True,
85+
},
86+
{
87+
"name": "SAGEMAKER_ENV",
88+
"type": "text",
89+
"default": "1",
90+
"scope": "container",
91+
"required_for_model_class": True,
92+
},
93+
{
94+
"name": "HF_MODEL_ID",
95+
"type": "text",
96+
"default": "/opt/ml/model",
97+
"scope": "container",
98+
"required_for_model_class": True,
99+
},
100+
{
101+
"name": "SM_NUM_GPUS",
102+
"type": "text",
103+
"default": "8",
104+
"scope": "container",
105+
"required_for_model_class": True,
106+
},
107+
{
108+
"name": "MAX_INPUT_LENGTH",
109+
"type": "text",
110+
"default": "1024",
111+
"scope": "container",
112+
"required_for_model_class": True,
113+
},
114+
{
115+
"name": "MAX_TOTAL_TOKENS",
116+
"type": "text",
117+
"default": "2048",
118+
"scope": "container",
119+
"required_for_model_class": True,
120+
},
121+
{
122+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
123+
"type": "int",
124+
"default": 1,
125+
"scope": "container",
126+
"required_for_model_class": True,
127+
},
128+
],
129+
"metrics": [],
130+
"default_inference_instance_type": "ml.p4de.24xlarge",
131+
"supported_inference_instance_types": ["ml.p4de.24xlarge"],
132+
"model_kwargs": {},
133+
"deploy_kwargs": {
134+
"model_data_download_timeout": 3600,
135+
"container_startup_health_check_timeout": 3600,
136+
},
137+
"predictor_specs": {
138+
"supported_content_types": ["application/json"],
139+
"supported_accept_types": ["application/json"],
140+
"default_content_type": "application/json",
141+
"default_accept_type": "application/json",
142+
},
143+
"inference_volume_size": 512,
144+
"inference_enable_network_isolation": True,
145+
"validation_supported": False,
146+
"fine_tuning_supported": False,
147+
"resource_name_base": "hf-llm-falcon-180b-bf16",
148+
"hosting_instance_type_variants": {
149+
"regional_aliases": {
150+
"us-west-2": {
151+
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
152+
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
153+
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
154+
}
155+
},
156+
"variants": {
157+
"g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
158+
"g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
159+
"local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
160+
"p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
161+
"p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
162+
"p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
163+
"p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
164+
"p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
165+
"p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}},
166+
"ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}},
167+
"ml.p4d.24xlarge": {
168+
"properties": {
169+
"environment_variables": {
170+
"YODEL": "NACEREMA",
171+
}
172+
}
173+
},
174+
},
175+
},
176+
},
17177
"variant-model": {
18178
"model_id": "pytorch-ic-mobilenet-v2",
19179
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",

0 commit comments

Comments
 (0)