Skip to content

Commit 06a8557

Browse files
authored
fix: gated training environment variables for jumpstart (#1338)
1 parent f6a1692 commit 06a8557

File tree

8 files changed

+790
-17
lines changed

8 files changed

+790
-17
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
JUMPSTART_DEFAULT_REGION_NAME,
19+
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
1920
)
2021
from sagemaker.jumpstart.enums import (
2122
JumpStartScriptScope,
2223
)
2324
from sagemaker.jumpstart.utils import (
25+
get_jumpstart_gated_content_bucket,
2426
verify_model_region_and_return_specs,
2527
)
2628
from sagemaker.session import Session
@@ -102,10 +104,89 @@ def _retrieve_default_environment_variables(
102104
elif script == JumpStartScriptScope.TRAINING and getattr(
103105
model_specs, "training_instance_type_variants", None
104106
):
105-
default_environment_variables.update(
106-
model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
107-
instance_type
108-
)
107+
instance_specific_environment_variables = model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301
108+
instance_type
109109
)
110110

111+
default_environment_variables.update(instance_specific_environment_variables)
112+
113+
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
114+
model_id=model_id,
115+
model_version=model_version,
116+
region=region,
117+
tolerate_vulnerable_model=tolerate_vulnerable_model,
118+
tolerate_deprecated_model=tolerate_deprecated_model,
119+
sagemaker_session=sagemaker_session,
120+
instance_type=instance_type,
121+
)
122+
123+
if gated_model_env_var is not None:
124+
default_environment_variables.update(
125+
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}
126+
)
127+
111128
return default_environment_variables
129+
130+
131+
def _retrieve_gated_model_uri_env_var_value(
    model_id: str,
    model_version: str,
    region: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
    sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    instance_type: Optional[str] = None,
) -> Optional[str]:
    """Retrieves the gated model env var URI matching the given arguments.

    The URI is assembled as ``s3://<gated content bucket for region>/<s3 key>``,
    where the key is the instance specific gated model key found in the
    training instance type variants of the model specs.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the gated model env var URI.
        model_version (str): Version of the JumpStart model for which to retrieve the
            gated model env var URI.
        region (Optional[str]): Region for which to retrieve the gated model env var URI.
            If None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used. (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions. If not
            specified, one is created using the default AWS configuration
            chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
        instance_type (Optional[str]): An instance type to optionally supply in order to get
            the gated model key specific for the instance type. (Default: None).

    Returns:
        Optional[str]: the s3 URI to use for the environment variable, or None if the model
            does not have gated training artifacts (no training instance type variants, or
            no gated model key for the instance type).

    Raises:
        ValueError: If the model specs specified are invalid.
    """

    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
    )

    # Model specs are not guaranteed to carry training instance type variants.
    # Without them there is no gated model key to look up, so report "no gated
    # artifacts" instead of failing with an AttributeError on None.
    instance_type_variants = getattr(model_specs, "training_instance_type_variants", None)
    if instance_type_variants is None:
        return None

    s3_key: Optional[
        str
    ] = instance_type_variants.get_instance_specific_gated_model_key_env_var_value(instance_type)
    if s3_key is None:
        return None

    return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}"

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Dict, List, Optional, Union
1818
from sagemaker import (
19+
environment_variables,
1920
hyperparameters as hyperparameters_utils,
2021
image_uris,
2122
instance_types,
@@ -557,6 +558,18 @@ def _add_env_to_kwargs(
557558
) -> JumpStartEstimatorInitKwargs:
558559
"""Sets environment in kwargs based on default or override, returns full kwargs."""
559560

561+
extra_env_vars = environment_variables.retrieve_default(
562+
model_id=kwargs.model_id,
563+
model_version=kwargs.model_version,
564+
region=kwargs.region,
565+
include_aws_sdk_env_vars=False,
566+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
567+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
568+
sagemaker_session=kwargs.sagemaker_session,
569+
script=JumpStartScriptScope.TRAINING,
570+
instance_type=kwargs.instance_type,
571+
)
572+
560573
model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri(
561574
model_id=kwargs.model_id,
562575
model_version=kwargs.model_version,
@@ -568,12 +581,16 @@ def _add_env_to_kwargs(
568581
)
569582

570583
if model_package_artifact_uri:
571-
if kwargs.environment is None:
572-
kwargs.environment = {}
573-
kwargs.environment = {
574-
**{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri},
575-
**kwargs.environment,
576-
}
584+
extra_env_vars.update(
585+
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: model_package_artifact_uri}
586+
)
587+
588+
for key, value in extra_env_vars.items():
589+
update_dict_if_key_not_present(
590+
kwargs.environment,
591+
key,
592+
value,
593+
)
577594

578595
return kwargs
579596

src/sagemaker/jumpstart/types.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,16 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
581581

582582
return instance_family_environment_variables
583583

584+
def get_instance_specific_gated_model_key_env_var_value(
    self, instance_type: str
) -> Optional[str]:
    """Looks up the gated model env var s3 key for the given instance type.

    The lookup is delegated to ``_get_instance_specific_property`` using the
    ``gated_model_key_env_var_value`` property name, and its result is
    returned unchanged.

    Args:
        instance_type (str): Instance type for which to look up the key.

    Returns:
        Optional[str]: The instance specific s3 key, or None when the
            (model, instance type) pair has no such instance specific property.
    """
    gated_model_key = self._get_instance_specific_property(
        instance_type, "gated_model_key_env_var_value"
    )
    return gated_model_key
593+
584594
def get_instance_specific_default_inference_instance_type(
585595
self, instance_type: str
586596
) -> Optional[str]:
@@ -901,10 +911,12 @@ def use_inference_script_uri(self) -> bool:
901911

902912
def use_training_model_artifact(self) -> bool:
    """Returns True if the model should use a model uri when kicking off training job.

    A model uri is used only when both of the following hold:
      * the model is not gated (``gated_bucket`` is falsy), and
      * no training model package artifact uris are set (None or empty).
    """
    # Short-circuits to False for gated models; otherwise the answer is
    # whether the training model package artifact uris are missing or empty.
    return not self.gated_bucket and len(self.training_model_package_artifact_uris or {}) == 0
908920

909921
def supports_incremental_training(self) -> bool:
910922
"""Returns True if the model supports incremental training."""

src/sagemaker/jumpstart/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,11 +611,18 @@ def verify_model_region_and_return_specs(
611611

612612

613613
def update_dict_if_key_not_present(
    dict_to_update: Optional[dict], key_to_add: Any, value_to_add: Any
) -> dict:
    """If a key is not present in the dict, add the new (key, value) pair, and return dict.

    A value already stored under ``key_to_add`` is never overwritten.

    If ``dict_to_update`` is None, a new dict containing only the new pair is
    created and returned. In that case nothing is mutated in place, so callers
    that may pass None must use the return value to observe the update.

    Args:
        dict_to_update (Optional[dict]): Dict to update in place, or None.
        key_to_add (Any): Key to add if it is missing.
        value_to_add (Any): Value stored under ``key_to_add`` when the key is missing.

    Returns:
        dict: ``dict_to_update`` itself (mutated in place) when one was supplied,
            otherwise the newly created dict. The result always contains
            ``key_to_add`` and is therefore never empty.
    """
    if dict_to_update is None:
        return {key_to_add: value_to_add}

    # setdefault only writes when the key is missing, preserving existing values.
    dict_to_update.setdefault(key_to_add, value_to_add)

    return dict_to_update
621628

0 commit comments

Comments
 (0)