Skip to content

Commit 95780d3

Browse files
evakravi authored and bencrabtree committed
chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied (aws#4485)
* fix: raise exception when no instance specific gated training env var available
* chore: raise client exception if accept_eula flag is not set for gated models
* chore: address flake8 errors
* chore: emit warning when instance type is chosen with no gated training artifacts
1 parent 525e9ae commit 95780d3

File tree

8 files changed

+1388
-50
lines changed

8 files changed

+1388
-50
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart environment variables."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Optional
15+
from typing import Callable, Dict, Optional, Set
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
JUMPSTART_DEFAULT_REGION_NAME,
19+
JUMPSTART_LOGGER,
1920
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
2021
)
2122
from sagemaker.jumpstart.enums import (
@@ -114,7 +115,9 @@ def _retrieve_default_environment_variables(
114115

115116
default_environment_variables.update(instance_specific_environment_variables)
116117

117-
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
118+
retrieve_gated_env_var_for_instance_type: Callable[
119+
[str], Optional[str]
120+
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
118121
model_id=model_id,
119122
model_version=model_version,
120123
hub_arn=hub_arn,
@@ -125,6 +128,33 @@ def _retrieve_default_environment_variables(
125128
instance_type=instance_type,
126129
)
127130

131+
gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type(
132+
instance_type
133+
)
134+
135+
if gated_model_env_var is None and model_specs.is_gated_model():
136+
137+
possible_env_vars: Set[str] = {
138+
retrieve_gated_env_var_for_instance_type(instance_type)
139+
for instance_type in model_specs.supported_training_instance_types
140+
}
141+
142+
# If all officially supported instance types have the same underlying artifact,
143+
# we can use this artifact with high confidence that it'll succeed with
144+
# an arbitrary instance.
145+
if len(possible_env_vars) == 1:
146+
gated_model_env_var = list(possible_env_vars)[0]
147+
148+
# If this model does not have 1 artifact for all supported instance types,
149+
# we cannot determine which artifact to use for an arbitrary instance.
150+
else:
151+
log_msg = (
152+
f"'{model_id}' does not support {instance_type} instance type"
153+
" for training. Please use one of the following instance types: "
154+
f"{', '.join(model_specs.supported_training_instance_types)}."
155+
)
156+
JUMPSTART_LOGGER.warning(log_msg)
157+
128158
if gated_model_env_var is not None:
129159
default_environment_variables.update(
130160
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from sagemaker.jumpstart.utils import (
6464
add_hub_arn_tags,
6565
add_jumpstart_model_id_version_tags,
66+
get_eula_message,
6667
update_dict_if_key_not_present,
6768
resolve_estimator_sagemaker_config_field,
6869
verify_model_region_and_return_specs,
@@ -617,6 +618,26 @@ def _add_env_to_kwargs(
617618
value,
618619
)
619620

621+
environment = getattr(kwargs, "environment", {}) or {}
622+
if (
623+
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
624+
and str(environment.get("accept_eula", "")).lower() != "true"
625+
):
626+
model_specs = verify_model_region_and_return_specs(
627+
model_id=kwargs.model_id,
628+
version=kwargs.model_version,
629+
region=kwargs.region,
630+
scope=JumpStartScriptScope.TRAINING,
631+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
632+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
633+
sagemaker_session=kwargs.sagemaker_session,
634+
)
635+
if model_specs.is_gated_model():
636+
raise ValueError(
637+
"Need to define ‘accept_eula'='true' within Environment. "
638+
f"{get_eula_message(model_specs, kwargs.region)}"
639+
)
640+
620641
return kwargs
621642

622643

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,10 @@ def use_training_model_artifact(self) -> bool:
991991
# otherwise, return true is a training model package is not set
992992
return len(self.training_model_package_artifact_uris or {}) == 0
993993

994+
def is_gated_model(self) -> bool:
    """Returns True if the model has a EULA key or the model bucket is gated.

    The model counts as gated when ``self.gated_bucket`` is truthy or when
    ``self.hosting_eula_key`` is not ``None``.

    Returns:
        bool: Always a real ``bool``, never the raw ``gated_bucket`` value.
    """
    # Coerce ``gated_bucket`` explicitly: a bare ``a or b`` hands back ``a`` itself
    # when it is truthy, which would leak a non-bool through the ``-> bool`` contract.
    return bool(self.gated_bucket) or self.hosting_eula_key is not None
997+
994998
def supports_incremental_training(self) -> bool:
995999
"""Returns True if the model supports incremental training."""
9961000
return self.incremental_training_supported

src/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -491,21 +491,25 @@ def update_inference_tags_with_jumpstart_training_tags(
491491
return inference_tags
492492

493493

494+
def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
    """Builds the EULA notice for a model.

    Args:
        model_specs: Specs of the model; ``model_id`` and ``hosting_eula_key`` are read.
        region: Region used to look up the content bucket and to build the S3 URL.

    Returns:
        str: A message pointing at the EULA document, or an empty string when
        ``model_specs.hosting_eula_key`` is ``None``.
    """
    eula_key = model_specs.hosting_eula_key
    if eula_key is not None:
        content_bucket = get_jumpstart_content_bucket(region=region)
        # Regions whose name starts with "cn-" get an extra ".cn" on the domain.
        domain_suffix = ".cn" if region.startswith("cn-") else ""
        return (
            f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
            f"See https://{content_bucket}.s3.{region}."
            f"amazonaws.com{domain_suffix}"
            f"/{eula_key} for terms of use."
        )
    return ""
504+
505+
494506
def emit_logs_based_on_model_specs(
495507
model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
496508
) -> None:
497509
"""Emits logs based on model specs and region."""
498510

499511
if model_specs.hosting_eula_key:
500-
constants.JUMPSTART_LOGGER.info(
501-
"Model '%s' requires accepting end-user license agreement (EULA). "
502-
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
503-
model_specs.model_id,
504-
get_jumpstart_content_bucket(region=region),
505-
region,
506-
".cn" if region.startswith("cn-") else "",
507-
model_specs.hosting_eula_key,
508-
)
512+
constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))
509513

510514
full_version: str = model_specs.version
511515

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker import environment_variables
21+
from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket
2122
from sagemaker.jumpstart.enums import JumpStartModelType
2223

2324
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
@@ -207,6 +208,70 @@ def test_jumpstart_sdk_environment_variables(
207208
)
208209

209210

211+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs):
    """Checks the default training env vars returned for ``gemma-model-1-artifact``.

    With ``ml.p3.2xlarge`` as the instance type, ``retrieve_default`` is expected
    to return exactly one entry: the gated-model S3 URI.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    region = "us-west-2"
    model_id = "gemma-model-1-artifact"

    expected_env_vars = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
        "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }

    retrieved_env_vars = environment_variables.retrieve_default(
        region=region,
        model_id=model_id,
        model_version="*",
        include_aws_sdk_env_vars=False,
        sagemaker_session=mock_session,
        instance_type="ml.p3.2xlarge",
        script="training",
    )

    assert expected_env_vars == retrieved_env_vars
231+
232+
233+
@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(
    patched_get_model_specs, patched_jumpstart_logger
):
    """Checks the warning path when no gated training env var can be resolved.

    For ``gemma-model`` on ``ml.p3.2xlarge`` the result must be empty and exactly
    one warning must be logged; on ``ml.g5.24xlarge`` the gated-model S3 URI must
    be returned.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    region = "us-west-2"
    model_id = "gemma-model"

    # Arguments shared by both retrieve_default calls; only instance_type varies.
    common_kwargs = dict(
        region=region,
        model_id=model_id,
        model_version="*",
        include_aws_sdk_env_vars=False,
        sagemaker_session=mock_session,
        script="training",
    )

    unsupported_result = environment_variables.retrieve_default(
        instance_type="ml.p3.2xlarge", **common_kwargs
    )
    assert {} == unsupported_result

    patched_jumpstart_logger.warning.assert_called_once_with(
        "'gemma-model' does not support ml.p3.2xlarge instance type for "
        "training. Please use one of the following instance types: "
        "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
    )

    # assert that supported instance types succeed
    expected_env_vars = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
        "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }
    supported_result = environment_variables.retrieve_default(
        instance_type="ml.g5.24xlarge", **common_kwargs
    )
    assert expected_env_vars == supported_result
273+
274+
210275
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
211276
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
212277

0 commit comments

Comments
 (0)