Skip to content

Commit 2367d16

Browse files
committed
feat: jumpstart EULA models (initial commit)
1 parent 53fe9b7 commit 2367d16

File tree

8 files changed

+154
-1
lines changed

8 files changed

+154
-1
lines changed

src/sagemaker/jumpstart/artifacts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@
5252
_retrieve_supported_accept_types,
5353
_retrieve_supported_content_types,
5454
)
55+
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functions for obtaining JumpStart model packages."""
14+
from __future__ import absolute_import
15+
from copy import deepcopy
16+
from typing import Dict, List, Optional
17+
from sagemaker.jumpstart.constants import (
18+
JUMPSTART_DEFAULT_REGION_NAME,
19+
)
20+
from sagemaker.jumpstart.enums import (
21+
JumpStartScriptScope,
22+
)
23+
from sagemaker.jumpstart.utils import (
24+
verify_model_region_and_return_specs,
25+
)
26+
27+
28+
def _retrieve_model_package_arn(
29+
model_id: str,
30+
model_version: str,
31+
region: Optional[str],
32+
tolerate_vulnerable_model: bool = False,
33+
tolerate_deprecated_model: bool = False,
34+
) -> Optional[str]:
35+
"""Retrieves associated model pacakge arn for the model.
36+
37+
Args:
38+
model_id (str): JumpStart model ID of the JumpStart model for which to
39+
retrieve the model package arn.
40+
model_version (str): Version of the JumpStart model for which to retrieve the
41+
model package arn.
42+
region (Optional[str]): Region for which to retrieve the model package arn.
43+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
44+
specifications should be tolerated (exception not raised). If False, raises an
45+
exception if the script used by this version of the model has dependencies with known
46+
security vulnerabilities. (Default: False).
47+
tolerate_deprecated_model (bool): True if deprecated versions of model
48+
specifications should be tolerated (exception not raised). If False, raises
49+
an exception if the version of the model is deprecated. (Default: False).
50+
51+
Returns:
52+
list: the model package arn to use for the model or None.
53+
"""
54+
55+
if region is None:
56+
region = JUMPSTART_DEFAULT_REGION_NAME
57+
58+
model_specs = verify_model_region_and_return_specs(
59+
model_id=model_id,
60+
version=model_version,
61+
scope=JumpStartScriptScope.TRAINING,
62+
region=region,
63+
tolerate_vulnerable_model=tolerate_vulnerable_model,
64+
tolerate_deprecated_model=tolerate_deprecated_model,
65+
)
66+
67+
return model_specs.model_package_arn

src/sagemaker/jumpstart/factory/model.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_model_supports_prepacked_inference,
2626
_retrieve_model_init_kwargs,
2727
_retrieve_model_deploy_kwargs,
28+
_retrieve_model_package_arn,
2829
)
2930
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3031
from sagemaker.jumpstart.constants import (
@@ -289,6 +290,21 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
289290
return kwargs
290291

291292

293+
def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
294+
"""Sets model package arn based on default or override, returns full kwargs."""
295+
296+
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
297+
model_id=kwargs.model_id,
298+
model_version=kwargs.model_version,
299+
region=kwargs.region,
300+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
301+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
302+
)
303+
304+
kwargs.model_package_arn = model_package_arn
305+
return kwargs
306+
307+
292308
def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
293309
"""Sets extra kwargs based on default or override, returns full kwargs."""
294310

@@ -471,6 +487,7 @@ def get_init_kwargs(
471487
container_log_level: Optional[Union[int, PipelineVariable]] = None,
472488
dependencies: Optional[List[str]] = None,
473489
git_config: Optional[Dict[str, str]] = None,
490+
model_package_arn: Optional[str] = None,
474491
) -> JumpStartModelInitKwargs:
475492
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
476493

@@ -498,6 +515,7 @@ def get_init_kwargs(
498515
git_config=git_config,
499516
tolerate_deprecated_model=tolerate_deprecated_model,
500517
tolerate_vulnerable_model=tolerate_vulnerable_model,
518+
model_package_arn=model_package_arn,
501519
)
502520

503521
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
@@ -526,4 +544,6 @@ def get_init_kwargs(
526544
model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs)
527545
model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs)
528546

547+
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)
548+
529549
return model_init_kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import absolute_import
1616
import logging
17+
import re
1718

1819
from typing import Dict, List, Optional, Union
1920
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
@@ -30,7 +31,7 @@
3031
)
3132
from sagemaker.jumpstart.utils import is_valid_model_id
3233
from sagemaker.utils import stringify_object
33-
from sagemaker.model import Model
34+
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model, ModelPackage
3435
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3536
from sagemaker.predictor import PredictorBase
3637
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
@@ -71,6 +72,7 @@ def __init__(
7172
container_log_level: Optional[Union[int, PipelineVariable]] = None,
7273
dependencies: Optional[List[str]] = None,
7374
git_config: Optional[Dict[str, str]] = None,
75+
model_package_arn: Optional[str] = None,
7476
):
7577
"""Initializes a ``JumpStartModel``.
7678
@@ -249,6 +251,9 @@ def __init__(
249251
>>> 'branch': 'test-branch-git-config',
250252
>>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
251253
254+
model_package_arn (Optional[str]): An existing SageMaker Model Package arn,
255+
can be just the name if your account owns the Model Package.
256+
``model_data`` is not required. (Default: None).
252257
Raises:
253258
ValueError: If the model ID is not recognized by JumpStart.
254259
"""
@@ -291,6 +296,7 @@ def _is_valid_model_id_hook():
291296
container_log_level=container_log_level,
292297
dependencies=dependencies,
293298
git_config=git_config,
299+
model_package_arn=model_package_arn,
294300
)
295301

296302
self.orig_predictor_cls = predictor_cls
@@ -301,9 +307,49 @@ def _is_valid_model_id_hook():
301307
self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
302308
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
303309
self.region = model_init_kwargs.region
310+
self.model_package_arn = model_init_kwargs.model_package_arn
304311

305312
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
306313

314+
def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-argument
315+
"""Create a SageMaker Model Entity
316+
317+
Args:
318+
args: Positional arguments coming from the caller. This class does not require
319+
any so they are ignored.
320+
321+
kwargs: Keyword arguments coming from the caller. This class does not require
322+
any so they are ignored.
323+
"""
324+
if self.model_package_arn:
325+
# When a ModelPackageArn is provided we just create the Model
326+
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
327+
if match:
328+
model_package_name = match.group(3)
329+
else:
330+
# model_package_arn can be just the name if your account owns the Model Package
331+
model_package_name = self.model_package_arn
332+
container_def = {"ModelPackageName": self.model_package_arn}
333+
334+
if self.env != {}:
335+
container_def["Environment"] = self.env
336+
337+
if self.name is None:
338+
self._base_name = model_package_name
339+
340+
self._set_model_name_if_needed()
341+
342+
self.sagemaker_session.create_model(
343+
self.name,
344+
self.role,
345+
container_def,
346+
vpc_config=self.vpc_config,
347+
enable_network_isolation=self.enable_network_isolation(),
348+
tags=kwargs.get("tags"),
349+
)
350+
else:
351+
super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs)
352+
307353
def deploy(
308354
self,
309355
initial_instance_count: Optional[int] = None,

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
351351
"inference_enable_network_isolation",
352352
"training_enable_network_isolation",
353353
"resource_name_base",
354+
"eula_model",
355+
"model_package_arn",
354356
]
355357

356358
def __init__(self, spec: Dict[str, Any]):
@@ -419,6 +421,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
419421
)
420422
self.resource_name_base: bool = json_obj.get("resource_name_base")
421423

424+
self.eula_model: bool = json_obj.get("eula_model", False)
425+
426+
self.model_package_arn: Optional[str] = json_obj.get("model_package_arn")
427+
422428
if self.training_supported:
423429
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
424430
json_obj["training_ecr_specs"]
@@ -574,6 +580,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
574580
"container_log_level",
575581
"dependencies",
576582
"git_config",
583+
"model_package_arn",
577584
]
578585

579586
SERIALIZATION_EXCLUSION_SET = {
@@ -583,6 +590,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
583590
"tolerate_vulnerable_model",
584591
"tolerate_deprecated_model",
585592
"region",
593+
"model_package_arn",
586594
}
587595

588596
def __init__(
@@ -610,6 +618,7 @@ def __init__(
610618
git_config: Optional[Dict[str, str]] = None,
611619
tolerate_vulnerable_model: Optional[bool] = None,
612620
tolerate_deprecated_model: Optional[bool] = None,
621+
model_package_arn: Optional[str] = None,
613622
) -> None:
614623
"""Instantiates JumpStartModelInitKwargs object."""
615624

@@ -636,6 +645,7 @@ def __init__(
636645
self.git_config = git_config
637646
self.tolerate_deprecated_model = tolerate_deprecated_model
638647
self.tolerate_vulnerable_model = tolerate_vulnerable_model
648+
self.model_package_arn = model_package_arn
639649

640650

641651
class JumpStartModelDeployKwargs(JumpStartKwargs):

src/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,12 @@ def verify_model_region_and_return_specs(
402402
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403403
)
404404

405+
if model_specs.eula_model:
406+
LOGGER.info(
407+
"Using model with end-user license agreement (EULA). "
408+
"Deploying this model requires accepting EULA terms."
409+
)
410+
405411
if model_specs.deprecated:
406412
if not tolerate_deprecated_model:
407413
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,6 +2299,8 @@
22992299
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23002300
"training_prepacked_script_key": None,
23012301
"hosting_prepacked_artifact_key": None,
2302+
"model_package_arn": None,
2303+
"eula_model": False,
23022304
"hyperparameters": [
23032305
{
23042306
"name": "epochs",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
345345
"tolerate_vulnerable_model",
346346
"tolerate_deprecated_model",
347347
"instance_type",
348+
"model_package_arn",
348349
}
349350
assert parent_class_init_args - js_class_init_args == init_args_to_skip
350351

0 commit comments

Comments
 (0)