Skip to content

Commit 0a5b450

Browse files
authored
feat: jumpstart EULA models (#3999)
1 parent b8f659a commit 0a5b450

File tree

10 files changed

+212
-2
lines changed

10 files changed

+212
-2
lines changed

src/sagemaker/base_predictor.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def predict(
133133
target_model=None,
134134
target_variant=None,
135135
inference_id=None,
136+
custom_attributes=None,
136137
):
137138
"""Return the inference from the specified endpoint.
138139
@@ -153,6 +154,18 @@ def predict(
153154
model you want to host and the resources you want to deploy for hosting it.
154155
inference_id (str): If you provide a value, it is added to the captured data
155156
when you enable data capture on the endpoint (Default: None).
157+
custom_attributes (str): Provides additional information about a request for an
158+
inference submitted to a model hosted at an Amazon SageMaker endpoint.
159+
The information is an opaque value that is forwarded verbatim. You could use this
160+
value, for example, to provide an ID that you can use to track a request or to
161+
provide other metadata that a service endpoint was programmed to process. The value
162+
must consist of no more than 1024 visible US-ASCII characters.
163+
164+
The code in your model is responsible for setting or updating any custom attributes
165+
in the response. If your code does not set this value in the response, an empty
166+
value is returned. For example, if a custom attribute represents the trace ID, your
167+
model can prepend the custom attribute with Trace ID: in your post-processing
168+
function (Default: None).
156169
157170
Returns:
158171
object: Inference for the given input. If a deserializer was specified when creating
@@ -162,7 +175,12 @@ def predict(
162175
"""
163176

164177
request_args = self._create_request_args(
165-
data, initial_args, target_model, target_variant, inference_id
178+
data,
179+
initial_args,
180+
target_model,
181+
target_variant,
182+
inference_id,
183+
custom_attributes,
166184
)
167185
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
168186
return self._handle_response(response)
@@ -180,6 +198,7 @@ def _create_request_args(
180198
target_model=None,
181199
target_variant=None,
182200
inference_id=None,
201+
custom_attributes=None,
183202
):
184203
"""Placeholder docstring"""
185204
args = dict(initial_args) if initial_args else {}
@@ -206,6 +225,9 @@ def _create_request_args(
206225
if inference_id:
207226
args["InferenceId"] = inference_id
208227

228+
if custom_attributes:
229+
args["CustomAttributes"] = custom_attributes
230+
209231
data = self.serializer.serialize(data)
210232

211233
args["Body"] = data

src/sagemaker/jumpstart/artifacts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@
5252
_retrieve_supported_accept_types,
5353
_retrieve_supported_content_types,
5454
)
55+
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn # noqa: F401
src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functions for obtaining JumpStart model packages."""
14+
from __future__ import absolute_import
15+
from typing import Optional
16+
from sagemaker.jumpstart.constants import (
17+
JUMPSTART_DEFAULT_REGION_NAME,
18+
)
19+
from sagemaker.jumpstart.utils import (
20+
verify_model_region_and_return_specs,
21+
)
22+
from sagemaker.jumpstart.enums import (
23+
JumpStartScriptScope,
24+
)
25+
26+
27+
def _retrieve_model_package_arn(
    model_id: str,
    model_version: str,
    region: Optional[str],
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
) -> Optional[str]:
    """Retrieves the model package arn associated with a JumpStart model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the model package arn.
        model_version (str): Version of the JumpStart model for which to retrieve the
            model package arn.
        region (Optional[str]): Region for which to retrieve the model package arn.
            If None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used.
        scope (Optional[str]): Scope for which to retrieve the model package arn.
            Only the inference scope is supported.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).

    Returns:
        str: the model package arn to use for the model in the resolved region, or
            None if the model specs define no model package arns or none for that region.

    Raises:
        NotImplementedError: If ``scope`` is anything other than the inference scope.
    """
    effective_region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    # The specs are always fetched (and validated) before the scope is inspected,
    # so validation errors take precedence over the unsupported-scope error.
    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=scope,
        region=effective_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    if scope != JumpStartScriptScope.INFERENCE:
        raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")

    arns_by_region = specs.hosting_model_package_arns
    if arns_by_region is None:
        return None

    return arns_by_region.get(effective_region)

src/sagemaker/jumpstart/factory/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_model_supports_prepacked_inference,
2626
_retrieve_model_init_kwargs,
2727
_retrieve_model_deploy_kwargs,
28+
_retrieve_model_package_arn,
2829
)
2930
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3031
from sagemaker.jumpstart.constants import (
@@ -289,6 +290,22 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
289290
return kwargs
290291

291292

293+
def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets model package arn based on default or override, returns full kwargs.

    A truthy ``kwargs.model_package_arn`` supplied by the caller is kept as is;
    otherwise the arn is looked up for the inference scope of the model.
    """
    override_arn = kwargs.model_package_arn

    if override_arn:
        resolved_arn = override_arn
    else:
        resolved_arn = _retrieve_model_package_arn(
            model_id=kwargs.model_id,
            model_version=kwargs.model_version,
            scope=JumpStartScriptScope.INFERENCE,
            region=kwargs.region,
            tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
            tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        )

    kwargs.model_package_arn = resolved_arn
    return kwargs
307+
308+
292309
def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
293310
"""Sets extra kwargs based on default or override, returns full kwargs."""
294311

@@ -471,6 +488,7 @@ def get_init_kwargs(
471488
container_log_level: Optional[Union[int, PipelineVariable]] = None,
472489
dependencies: Optional[List[str]] = None,
473490
git_config: Optional[Dict[str, str]] = None,
491+
model_package_arn: Optional[str] = None,
474492
) -> JumpStartModelInitKwargs:
475493
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
476494

@@ -498,6 +516,7 @@ def get_init_kwargs(
498516
git_config=git_config,
499517
tolerate_deprecated_model=tolerate_deprecated_model,
500518
tolerate_vulnerable_model=tolerate_vulnerable_model,
519+
model_package_arn=model_package_arn,
501520
)
502521

503522
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
@@ -526,4 +545,6 @@ def get_init_kwargs(
526545
model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs)
527546
model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs)
528547

548+
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)
549+
529550
return model_init_kwargs

src/sagemaker/jumpstart/model.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import absolute_import
1616
import logging
17+
import re
1718

1819
from typing import Dict, List, Optional, Union
1920
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
@@ -30,7 +31,7 @@
3031
)
3132
from sagemaker.jumpstart.utils import is_valid_model_id
3233
from sagemaker.utils import stringify_object
33-
from sagemaker.model import Model
34+
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
3435
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
3536
from sagemaker.predictor import PredictorBase
3637
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
@@ -71,6 +72,7 @@ def __init__(
7172
container_log_level: Optional[Union[int, PipelineVariable]] = None,
7273
dependencies: Optional[List[str]] = None,
7374
git_config: Optional[Dict[str, str]] = None,
75+
model_package_arn: Optional[str] = None,
7476
):
7577
"""Initializes a ``JumpStartModel``.
7678
@@ -249,6 +251,9 @@ def __init__(
249251
>>> 'branch': 'test-branch-git-config',
250252
>>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}
251253
254+
model_package_arn (Optional[str]): An existing SageMaker Model Package arn,
255+
can be just the name if your account owns the Model Package.
256+
``model_data`` is not required. (Default: None).
252257
Raises:
253258
ValueError: If the model ID is not recognized by JumpStart.
254259
"""
@@ -291,6 +296,7 @@ def _is_valid_model_id_hook():
291296
container_log_level=container_log_level,
292297
dependencies=dependencies,
293298
git_config=git_config,
299+
model_package_arn=model_package_arn,
294300
)
295301

296302
self.orig_predictor_cls = predictor_cls
@@ -301,9 +307,49 @@ def _is_valid_model_id_hook():
301307
self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
302308
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
303309
self.region = model_init_kwargs.region
310+
self.model_package_arn = model_init_kwargs.model_package_arn
304311

305312
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
306313

314+
def _create_sagemaker_model(self, *args, **kwargs):  # pylint: disable=unused-argument
    """Create a SageMaker Model Entity.

    When ``self.model_package_arn`` is set, the model is created directly from
    that model package. Otherwise creation is delegated to the parent class.

    Args:
        args: Positional arguments coming from the caller. They are not used when
            creating the model from a model package; otherwise they are forwarded
            unchanged to the parent class.

        kwargs: Keyword arguments coming from the caller. When creating the model
            from a model package only ``tags`` is read; otherwise they are
            forwarded unchanged to the parent class.
    """
    if not self.model_package_arn:
        super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs)
        return

    # When a ModelPackageArn is provided we just create the Model.
    match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
    # model_package_arn can be just the name if your account owns the Model Package,
    # in which case the pattern does not match and the value is used verbatim.
    model_package_name = match.group(3) if match else self.model_package_arn

    # The container definition receives the value exactly as supplied; the
    # extracted name is only used below to seed the generated model name.
    container_def = {"ModelPackageName": self.model_package_arn}

    # Only attach an environment when there is one: a truthiness check also
    # skips ``None`` instead of sending ``"Environment": None``.
    if self.env:
        container_def["Environment"] = self.env

    if self.name is None:
        self._base_name = model_package_name

    self._set_model_name_if_needed()

    self.sagemaker_session.create_model(
        self.name,
        self.role,
        container_def,
        vpc_config=self.vpc_config,
        enable_network_isolation=self.enable_network_isolation(),
        tags=kwargs.get("tags"),
    )
352+
307353
def deploy(
308354
self,
309355
initial_instance_count: Optional[int] = None,

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
351351
"inference_enable_network_isolation",
352352
"training_enable_network_isolation",
353353
"resource_name_base",
354+
"hosting_eula_key",
355+
"hosting_model_package_arns",
354356
]
355357

356358
def __init__(self, spec: Dict[str, Any]):
@@ -419,6 +421,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
419421
)
420422
self.resource_name_base: bool = json_obj.get("resource_name_base")
421423

424+
self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")
425+
426+
self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns")
427+
422428
if self.training_supported:
423429
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
424430
json_obj["training_ecr_specs"]
@@ -574,6 +580,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
574580
"container_log_level",
575581
"dependencies",
576582
"git_config",
583+
"model_package_arn",
577584
]
578585

579586
SERIALIZATION_EXCLUSION_SET = {
@@ -583,6 +590,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
583590
"tolerate_vulnerable_model",
584591
"tolerate_deprecated_model",
585592
"region",
593+
"model_package_arn",
586594
}
587595

588596
def __init__(
@@ -610,6 +618,7 @@ def __init__(
610618
git_config: Optional[Dict[str, str]] = None,
611619
tolerate_vulnerable_model: Optional[bool] = None,
612620
tolerate_deprecated_model: Optional[bool] = None,
621+
model_package_arn: Optional[str] = None,
613622
) -> None:
614623
"""Instantiates JumpStartModelInitKwargs object."""
615624

@@ -636,6 +645,7 @@ def __init__(
636645
self.git_config = git_config
637646
self.tolerate_deprecated_model = tolerate_deprecated_model
638647
self.tolerate_vulnerable_model = tolerate_vulnerable_model
648+
self.model_package_arn = model_package_arn
639649

640650

641651
class JumpStartModelDeployKwargs(JumpStartKwargs):

src/sagemaker/jumpstart/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,17 @@ def verify_model_region_and_return_specs(
402402
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403403
)
404404

405+
if model_specs.hosting_eula_key and scope == constants.JumpStartScriptScope.INFERENCE.value:
406+
LOGGER.info(
407+
"Model '%s' requires accepting end-user license agreement (EULA). "
408+
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
409+
model_id,
410+
get_jumpstart_content_bucket(region=region),
411+
region,
412+
".cn" if region.startswith("cn-") else "",
413+
model_specs.hosting_eula_key,
414+
)
415+
405416
if model_specs.deprecated:
406417
if not tolerate_deprecated_model:
407418
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,6 +2299,8 @@
22992299
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
23002300
"training_prepacked_script_key": None,
23012301
"hosting_prepacked_artifact_key": None,
2302+
"hosting_model_package_arns": None,
2303+
"hosting_eula_key": None,
23022304
"hyperparameters": [
23032305
{
23042306
"name": "epochs",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
345345
"tolerate_vulnerable_model",
346346
"tolerate_deprecated_model",
347347
"instance_type",
348+
"model_package_arn",
348349
}
349350
assert parent_class_init_args - js_class_init_args == init_args_to_skip
350351

tests/unit/test_predictor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,3 +614,22 @@ def test_setting_serializer_deserializer_atts_changes_content_accept_types():
614614
predictor.deserializer = PandasDeserializer()
615615
assert predictor.accept == ("text/csv", "application/json")
616616
assert predictor.content_type == "text/csv"
617+
618+
619+
def test_custom_attributes():
    """``predict`` passes ``custom_attributes`` to the runtime as ``CustomAttributes``."""
    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)

    # Replace the runtime call with a mock so the outgoing request can be inspected.
    invoke_endpoint_mock = Mock(return_value={"Body": io.StringIO("response")})
    session.sagemaker_runtime_client.invoke_endpoint = invoke_endpoint_mock

    predictor.predict("payload", custom_attributes="custom-attribute")

    expected_request = {
        "EndpointName": ENDPOINT,
        "ContentType": "application/octet-stream",
        "Accept": "*/*",
        "CustomAttributes": "custom-attribute",
        "Body": "payload",
    }
    session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(**expected_request)

0 commit comments

Comments
 (0)