Skip to content

feat: jumpstart EULA models #3999

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 15, 2023
24 changes: 23 additions & 1 deletion src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def predict(
target_model=None,
target_variant=None,
inference_id=None,
custom_attributes=None,
):
"""Return the inference from the specified endpoint.

Expand All @@ -153,6 +154,18 @@ def predict(
model you want to host and the resources you want to deploy for hosting it.
inference_id (str): If you provide a value, it is added to the captured data
when you enable data capture on the endpoint (Default: None).
custom_attributes (str): Provides additional information about a request for an
inference submitted to a model hosted at an Amazon SageMaker endpoint.
The information is an opaque value that is forwarded verbatim. You could use this
value, for example, to provide an ID that you can use to track a request or to
provide other metadata that a service endpoint was programmed to process. The value
must consist of no more than 1024 visible US-ASCII characters.

The code in your model is responsible for setting or updating any custom attributes
in the response. If your code does not set this value in the response, an empty
value is returned. For example, if a custom attribute represents the trace ID, your
model can prepend the custom attribute with Trace ID: in your post-processing
function (Default: None).

Returns:
object: Inference for the given input. If a deserializer was specified when creating
Expand All @@ -162,7 +175,12 @@ def predict(
"""

request_args = self._create_request_args(
data, initial_args, target_model, target_variant, inference_id
data,
initial_args,
target_model,
target_variant,
inference_id,
custom_attributes,
)
response = self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint(**request_args)
return self._handle_response(response)
Expand All @@ -180,6 +198,7 @@ def _create_request_args(
target_model=None,
target_variant=None,
inference_id=None,
custom_attributes=None,
):
"""Placeholder docstring"""
args = dict(initial_args) if initial_args else {}
Expand All @@ -206,6 +225,9 @@ def _create_request_args(
if inference_id:
args["InferenceId"] = inference_id

if custom_attributes:
args["CustomAttributes"] = custom_attributes

data = self.serializer.serialize(data)

args["Body"] = data
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@
_retrieve_supported_accept_types,
_retrieve_supported_content_types,
)
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn # noqa: F401
77 changes: 77 additions & 0 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functions for obtaining JumpStart model packages."""
from __future__ import absolute_import
from typing import Optional
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.utils import (
verify_model_region_and_return_specs,
)
from sagemaker.jumpstart.enums import (
JumpStartScriptScope,
)


def _retrieve_model_package_arn(
    model_id: str,
    model_version: str,
    region: Optional[str],
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
) -> Optional[str]:
    """Retrieves the associated model package arn for the model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the model package arn.
        model_version (str): Version of the JumpStart model for which to retrieve the
            model package arn.
        region (Optional[str]): Region for which to retrieve the model package arn.
            ``JUMPSTART_DEFAULT_REGION_NAME`` is used when None.
        scope (Optional[str]): Scope for which to retrieve the model package arn.
            Only the inference scope is supported.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).

    Returns:
        Optional[str]: The model package arn listed for the region, or None when the
            model specs carry no model package arns or have no entry for the region.

    Raises:
        NotImplementedError: If ``scope`` is not the inference scope.
    """
    resolved_region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    # Specs are verified before the scope check so verification errors surface first.
    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=scope,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    if scope != JumpStartScriptScope.INFERENCE:
        raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")

    arns_by_region = specs.hosting_model_package_arns
    if arns_by_region is None:
        return None

    return arns_by_region.get(resolved_region)
21 changes: 21 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_model_supports_prepacked_inference,
_retrieve_model_init_kwargs,
_retrieve_model_deploy_kwargs,
_retrieve_model_package_arn,
)
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
from sagemaker.jumpstart.constants import (
Expand Down Expand Up @@ -289,6 +290,22 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
return kwargs


def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Sets model package arn based on default or override, returns full kwargs.

    A truthy ``kwargs.model_package_arn`` is kept as the override; otherwise the
    arn is looked up for the model's inference scope and stored on ``kwargs``.
    """
    resolved_arn = kwargs.model_package_arn

    if not resolved_arn:
        # No override supplied: fall back to the arn retrieved for this model.
        resolved_arn = _retrieve_model_package_arn(
            model_id=kwargs.model_id,
            model_version=kwargs.model_version,
            scope=JumpStartScriptScope.INFERENCE,
            region=kwargs.region,
            tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
            tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        )

    kwargs.model_package_arn = resolved_arn
    return kwargs


def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
"""Sets extra kwargs based on default or override, returns full kwargs."""

Expand Down Expand Up @@ -471,6 +488,7 @@ def get_init_kwargs(
container_log_level: Optional[Union[int, PipelineVariable]] = None,
dependencies: Optional[List[str]] = None,
git_config: Optional[Dict[str, str]] = None,
model_package_arn: Optional[str] = None,
) -> JumpStartModelInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -498,6 +516,7 @@ def get_init_kwargs(
git_config=git_config,
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
model_package_arn=model_package_arn,
)

model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
Expand Down Expand Up @@ -526,4 +545,6 @@ def get_init_kwargs(
model_init_kwargs = _add_extra_model_kwargs(kwargs=model_init_kwargs)
model_init_kwargs = _add_role_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
48 changes: 47 additions & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import absolute_import
import logging
import re

from typing import Dict, List, Optional, Union
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
Expand All @@ -30,7 +31,7 @@
)
from sagemaker.jumpstart.utils import is_valid_model_id
from sagemaker.utils import stringify_object
from sagemaker.model import Model
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
from sagemaker.predictor import PredictorBase
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
container_log_level: Optional[Union[int, PipelineVariable]] = None,
dependencies: Optional[List[str]] = None,
git_config: Optional[Dict[str, str]] = None,
model_package_arn: Optional[str] = None,
):
"""Initializes a ``JumpStartModel``.

Expand Down Expand Up @@ -249,6 +251,9 @@ def __init__(
>>> 'branch': 'test-branch-git-config',
>>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'}

model_package_arn (Optional[str]): An existing SageMaker Model Package ARN.
This can be just the name if your account owns the Model Package.
``model_data`` is not required. (Default: None).
Raises:
ValueError: If the model ID is not recognized by JumpStart.
"""
Expand Down Expand Up @@ -291,6 +296,7 @@ def _is_valid_model_id_hook():
container_log_level=container_log_level,
dependencies=dependencies,
git_config=git_config,
model_package_arn=model_package_arn,
)

self.orig_predictor_cls = predictor_cls
Expand All @@ -301,9 +307,49 @@ def _is_valid_model_id_hook():
self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
self.region = model_init_kwargs.region
self.model_package_arn = model_init_kwargs.model_package_arn

super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())

def _create_sagemaker_model(self, *args, **kwargs):  # pylint: disable=unused-argument
    """Create a SageMaker Model Entity.

    When ``self.model_package_arn`` is set, the model is created directly from the
    model package through ``self.sagemaker_session.create_model``. Otherwise the
    call is delegated to the parent class implementation.

    Args:
        args: Positional arguments coming from the caller. They are only forwarded
            to the parent implementation when no model package arn is set.

        kwargs: Keyword arguments coming from the caller. ``tags`` is read when
            creating the model from a model package; otherwise all of them are
            forwarded to the parent implementation.
    """
    if not self.model_package_arn:
        super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs)
        return

    # When a ModelPackageArn is provided we just create the Model
    arn_match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
    # model_package_arn can be just the name if your account owns the Model Package
    model_package_name = arn_match.group(3) if arn_match else self.model_package_arn

    container_def = {"ModelPackageName": self.model_package_arn}
    if self.env != {}:
        container_def["Environment"] = self.env

    # Seed the base name from the package name only when no explicit name was given.
    if self.name is None:
        self._base_name = model_package_name

    self._set_model_name_if_needed()

    self.sagemaker_session.create_model(
        self.name,
        self.role,
        container_def,
        vpc_config=self.vpc_config,
        enable_network_isolation=self.enable_network_isolation(),
        tags=kwargs.get("tags"),
    )

def deploy(
self,
initial_instance_count: Optional[int] = None,
Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
"inference_enable_network_isolation",
"training_enable_network_isolation",
"resource_name_base",
"hosting_eula_key",
"hosting_model_package_arns",
]

def __init__(self, spec: Dict[str, Any]):
Expand Down Expand Up @@ -419,6 +421,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
)
self.resource_name_base: bool = json_obj.get("resource_name_base")

self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")

self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns")

if self.training_supported:
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
json_obj["training_ecr_specs"]
Expand Down Expand Up @@ -574,6 +580,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"container_log_level",
"dependencies",
"git_config",
"model_package_arn",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -583,6 +590,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"tolerate_vulnerable_model",
"tolerate_deprecated_model",
"region",
"model_package_arn",
}

def __init__(
Expand Down Expand Up @@ -610,6 +618,7 @@ def __init__(
git_config: Optional[Dict[str, str]] = None,
tolerate_vulnerable_model: Optional[bool] = None,
tolerate_deprecated_model: Optional[bool] = None,
model_package_arn: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelInitKwargs object."""

Expand All @@ -636,6 +645,7 @@ def __init__(
self.git_config = git_config
self.tolerate_deprecated_model = tolerate_deprecated_model
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.model_package_arn = model_package_arn


class JumpStartModelDeployKwargs(JumpStartKwargs):
Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,17 @@ def verify_model_region_and_return_specs(
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
)

if model_specs.hosting_eula_key and scope == constants.JumpStartScriptScope.INFERENCE.value:
LOGGER.info(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See prior comment, we can perhaps give them the ability to retrieve the EULA from our public bucket

"Model '%s' requires accepting end-user license agreement (EULA). "
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
model_id,
get_jumpstart_content_bucket(region=region),
region,
".cn" if region.startswith("cn-") else "",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callout: this may not work for gov-cloud.

model_specs.hosting_eula_key,
)

if model_specs.deprecated:
if not tolerate_deprecated_model:
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,8 @@
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
"training_prepacked_script_key": None,
"hosting_prepacked_artifact_key": None,
"hosting_model_package_arns": None,
"hosting_eula_key": None,
"hyperparameters": [
{
"name": "epochs",
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
"tolerate_vulnerable_model",
"tolerate_deprecated_model",
"instance_type",
"model_package_arn",
}
assert parent_class_init_args - js_class_init_args == init_args_to_skip

Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,22 @@ def test_setting_serializer_deserializer_atts_changes_content_accept_types():
predictor.deserializer = PandasDeserializer()
assert predictor.accept == ("text/csv", "application/json")
assert predictor.content_type == "text/csv"


def test_custom_attributes():
    """``predict`` passes ``custom_attributes`` to ``invoke_endpoint`` as ``CustomAttributes``."""
    session = empty_sagemaker_session()
    predictor = Predictor(ENDPOINT, sagemaker_session=session)

    invoke_endpoint = Mock(return_value={"Body": io.StringIO("response")})
    session.sagemaker_runtime_client.invoke_endpoint = invoke_endpoint

    predictor.predict("payload", custom_attributes="custom-attribute")

    expected_request = {
        "EndpointName": ENDPOINT,
        "ContentType": "application/octet-stream",
        "Accept": "*/*",
        "CustomAttributes": "custom-attribute",
        "Body": "payload",
    }
    session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(**expected_request)