|
16 | 16 | from sagemaker.jumpstart.constants import (
|
17 | 17 | DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
|
18 | 18 | JUMPSTART_DEFAULT_REGION_NAME,
|
| 19 | + SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, |
19 | 20 | )
|
20 | 21 | from sagemaker.jumpstart.enums import (
|
21 | 22 | JumpStartScriptScope,
|
22 | 23 | )
|
23 | 24 | from sagemaker.jumpstart.utils import (
|
| 25 | + get_jumpstart_gated_content_bucket, |
24 | 26 | verify_model_region_and_return_specs,
|
25 | 27 | )
|
26 | 28 | from sagemaker.session import Session
|
@@ -102,10 +104,89 @@ def _retrieve_default_environment_variables(
|
102 | 104 | elif script == JumpStartScriptScope.TRAINING and getattr(
|
103 | 105 | model_specs, "training_instance_type_variants", None
|
104 | 106 | ):
|
105 |
| - default_environment_variables.update( |
106 |
| - model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301 |
107 |
| - instance_type |
108 |
| - ) |
| 107 | + instance_specific_environment_variables = model_specs.training_instance_type_variants.get_instance_specific_environment_variables( # noqa E501 # pylint: disable=c0301 |
| 108 | + instance_type |
109 | 109 | )
|
110 | 110 |
|
| 111 | + default_environment_variables.update(instance_specific_environment_variables) |
| 112 | + |
| 113 | + gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value( |
| 114 | + model_id=model_id, |
| 115 | + model_version=model_version, |
| 116 | + region=region, |
| 117 | + tolerate_vulnerable_model=tolerate_vulnerable_model, |
| 118 | + tolerate_deprecated_model=tolerate_deprecated_model, |
| 119 | + sagemaker_session=sagemaker_session, |
| 120 | + instance_type=instance_type, |
| 121 | + ) |
| 122 | + |
| 123 | + if gated_model_env_var is not None: |
| 124 | + default_environment_variables.update( |
| 125 | + {SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var} |
| 126 | + ) |
| 127 | + |
111 | 128 | return default_environment_variables
|
| 129 | + |
| 130 | + |
| 131 | +def _retrieve_gated_model_uri_env_var_value( |
| 132 | + model_id: str, |
| 133 | + model_version: str, |
| 134 | + region: Optional[str] = None, |
| 135 | + tolerate_vulnerable_model: bool = False, |
| 136 | + tolerate_deprecated_model: bool = False, |
| 137 | + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
| 138 | + instance_type: Optional[str] = None, |
| 139 | +) -> Optional[str]: |
| 140 | + """Retrieves the gated model env var URI matching the given arguments. |
| 141 | +
|
| 142 | + Args: |
| 143 | + model_id (str): JumpStart model ID of the JumpStart model for which to |
| 144 | + retrieve the gated model env var URI. |
| 145 | + model_version (str): Version of the JumpStart model for which to retrieve the |
| 146 | + gated model env var URI. |
| 147 | + region (Optional[str]): Region for which to retrieve the gated model env var URI. |
| 148 | + (Default: None). |
| 149 | + tolerate_vulnerable_model (bool): True if vulnerable versions of model |
| 150 | + specifications should be tolerated (exception not raised). If False, raises an |
| 151 | + exception if the script used by this version of the model has dependencies with known |
| 152 | + security vulnerabilities. (Default: False). |
| 153 | + tolerate_deprecated_model (bool): True if deprecated versions of model |
| 154 | + specifications should be tolerated (exception not raised). If False, raises |
| 155 | + an exception if the version of the model is deprecated. (Default: False). |
| 156 | + sagemaker_session (sagemaker.session.Session): A SageMaker Session |
| 157 | + object, used for SageMaker interactions. If not |
| 158 | + specified, one is created using the default AWS configuration |
| 159 | + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). |
| 160 | + instance_type (str): An instance type to optionally supply in order to get |
| 161 | + environment variables specific for the instance type. |
| 162 | +
|
| 163 | + Returns: |
| 164 | + Optional[str]: the s3 URI to use for the environment variable, or None if the model does not |
| 165 | + have gated training artifacts. |
| 166 | +
|
| 167 | + Raises: |
| 168 | + ValueError: If the model specs specified are invalid. |
| 169 | + """ |
| 170 | + |
| 171 | + if region is None: |
| 172 | + region = JUMPSTART_DEFAULT_REGION_NAME |
| 173 | + |
| 174 | + model_specs = verify_model_region_and_return_specs( |
| 175 | + model_id=model_id, |
| 176 | + version=model_version, |
| 177 | + scope=JumpStartScriptScope.TRAINING, |
| 178 | + region=region, |
| 179 | + tolerate_vulnerable_model=tolerate_vulnerable_model, |
| 180 | + tolerate_deprecated_model=tolerate_deprecated_model, |
| 181 | + sagemaker_session=sagemaker_session, |
| 182 | + ) |
| 183 | + |
| 184 | + s3_key: Optional[ |
| 185 | + str |
| 186 | + ] = model_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( # noqa E501 # pylint: disable=c0301 |
| 187 | + instance_type |
| 188 | + ) |
| 189 | + if s3_key is None: |
| 190 | + return None |
| 191 | + |
| 192 | + return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}" |
0 commit comments