Skip to content

Commit 427c22e

Browse files
committed
feat: jumpstart EULA models (initial commit)
1 parent 53fe9b7 commit 427c22e

File tree

5 files changed

+81
-10
lines changed

5 files changed

+81
-10
lines changed

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 42 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart model uris."""
1414
from __future__ import absolute_import
15+
import logging
1516
import os
1617
from typing import Optional
1718

@@ -26,6 +27,23 @@
2627
get_jumpstart_content_bucket,
2728
verify_model_region_and_return_specs,
2829
)
30+
from sagemaker.s3_utils import parse_s3_url
31+
32+
logger = logging.getLogger(__name__)
33+
34+
35+
_HOSTING_MODEL_ARTIFACT_URIS_PRIORITY_LIST = [
36+
"hosting_prepacked_artifact_uri",
37+
"hosting_artifact_uri",
38+
]
39+
_HOSTING_MODEL_ARTIFACT_KEYS_PRIORITY_LIST = [
40+
"hosting_prepacked_artifact_key",
41+
"hosting_artifact_key",
42+
]
43+
_MODEL_ARTIFACT_PRIORITY_LIST = [
44+
*_HOSTING_MODEL_ARTIFACT_URIS_PRIORITY_LIST,
45+
*_HOSTING_MODEL_ARTIFACT_KEYS_PRIORITY_LIST,
46+
]
2947

3048

3149
def _retrieve_model_uri(
@@ -76,18 +94,35 @@ def _retrieve_model_uri(
7694
tolerate_deprecated_model=tolerate_deprecated_model,
7795
)
7896

97+
bucket_override = os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
98+
99+
if bucket_override:
100+
logger.info("Using JumpStart bucket override for model URI: '%s'", bucket_override)
101+
79102
if model_scope == JumpStartScriptScope.INFERENCE:
80-
model_artifact_key = (
81-
getattr(model_specs, "hosting_prepacked_artifact_key", None)
82-
or model_specs.hosting_artifact_key
83-
)
103+
104+
model_artifact_key = None
105+
106+
for model_artifact_field_name in _MODEL_ARTIFACT_PRIORITY_LIST:
107+
field_value = getattr(model_specs, model_artifact_field_name, None)
108+
if field_value is not None:
109+
if model_artifact_field_name in _HOSTING_MODEL_ARTIFACT_URIS_PRIORITY_LIST:
110+
bucket, key = parse_s3_url(field_value)
111+
bucket_to_use = bucket_override or bucket
112+
field_value_to_use = f"s3://{bucket_to_use}/{key}"
113+
114+
return field_value_to_use
115+
116+
model_artifact_key = field_value
117+
break
118+
119+
if model_artifact_key is None:
120+
raise RuntimeError(f"Unable to find model artifact for '{model_id}'.")
84121

85122
elif model_scope == JumpStartScriptScope.TRAINING:
86123
model_artifact_key = model_specs.training_artifact_key
87124

88-
bucket = os.environ.get(
89-
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
90-
) or get_jumpstart_content_bucket(region)
125+
bucket = bucket_override or get_jumpstart_content_bucket(region)
91126

92127
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
93128

src/sagemaker/jumpstart/artifacts/script_uris.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart script uris."""
1414
from __future__ import absolute_import
15+
import logging
1516
import os
1617
from typing import Optional
1718
from sagemaker.jumpstart.constants import (
@@ -25,6 +26,9 @@
2526
get_jumpstart_content_bucket,
2627
verify_model_region_and_return_specs,
2728
)
29+
from sagemaker.s3_utils import parse_s3_url
30+
31+
logger = logging.getLogger(__name__)
2832

2933

3034
def _retrieve_script_uri(
@@ -76,16 +80,27 @@ def _retrieve_script_uri(
7680
tolerate_deprecated_model=tolerate_deprecated_model,
7781
)
7882

83+
bucket_override = os.environ.get(ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE)
84+
85+
if bucket_override:
86+
logger.info("Using JumpStart bucket override for script URI: '%s'", bucket_override)
87+
7988
if script_scope == JumpStartScriptScope.INFERENCE:
89+
script_uri = getattr(model_specs, "hosting_script_uri", None)
90+
if script_uri is not None:
91+
bucket, key = parse_s3_url(script_uri)
92+
bucket_to_use = bucket_override or bucket
93+
script_uri_to_use = f"s3://{bucket_to_use}/{key}"
94+
95+
return script_uri_to_use
96+
8097
model_script_key = model_specs.hosting_script_key
8198
elif script_scope == JumpStartScriptScope.TRAINING:
8299
model_script_key = (
83100
getattr(model_specs, "training_prepacked_script_key") or model_specs.training_script_key
84101
)
85102

86-
bucket = os.environ.get(
87-
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE
88-
) or get_jumpstart_content_bucket(region)
103+
bucket = bucket_override or get_jumpstart_content_bucket(region)
89104

90105
script_s3_uri = f"s3://{bucket}/{model_script_key}"
91106

src/sagemaker/jumpstart/types.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -351,6 +351,10 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
351351
"inference_enable_network_isolation",
352352
"training_enable_network_isolation",
353353
"resource_name_base",
354+
"eula_model",
355+
"hosting_prepacked_artifact_uri",
356+
"hosting_artifact_uri",
357+
"hosting_script_uri",
354358
]
355359

356360
def __init__(self, spec: Dict[str, Any]):
@@ -419,6 +423,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
419423
)
420424
self.resource_name_base: bool = json_obj.get("resource_name_base")
421425

426+
self.eula_model: bool = json_obj.get("eula_model", False)
427+
self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get(
428+
"hosting_prepacked_artifact_uri"
429+
)
430+
self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri")
431+
self.hosting_script_uri: Optional[str] = json_obj.get("hosting_script_uri")
432+
422433
if self.training_supported:
423434
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
424435
json_obj["training_ecr_specs"]

src/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -402,6 +402,12 @@ def verify_model_region_and_return_specs(
402402
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403403
)
404404

405+
if model_specs.eula_model:
406+
LOGGER.info(
407+
"Using model with end-user license agreement (EULA). "
408+
"Deploying this model requires accepting EULA terms."
409+
)
410+
405411
if model_specs.deprecated:
406412
if not tolerate_deprecated_model:
407413
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2294,11 +2294,15 @@
22942294
"py_version": "py3",
22952295
},
22962296
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
2297+
"hosting_artifact_uri": None,
22972298
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
22982299
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
2300+
"hosting_script_uri": None,
22992301
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23002302
"training_prepacked_script_key": None,
23012303
"hosting_prepacked_artifact_key": None,
2304+
"hosting_prepacked_artifact_uri": None,
2305+
"eula_model": False,
23022306
"hyperparameters": [
23032307
{
23042308
"name": "epochs",

0 commit comments

Comments
 (0)