Skip to content

Commit 2fc2264

Browse files
authored
Merge branch 'master' into feat/js-estimator-infra-check-flag
2 parents 0e85a1c + 7f6f3f9 commit 2fc2264

File tree

12 files changed

+1005
-10
lines changed

12 files changed

+1005
-10
lines changed

src/sagemaker/base_predictor.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import print_function, absolute_import
1515

1616
import abc
17-
from typing import Any, Tuple
17+
from typing import Any, Optional, Tuple, Union
1818

1919
from sagemaker.deprecations import (
2020
deprecated_class,
@@ -32,6 +32,9 @@
3232
StreamDeserializer,
3333
StringDeserializer,
3434
)
35+
from sagemaker.jumpstart.payload_utils import PayloadSerializer
36+
from sagemaker.jumpstart.types import JumpStartSerializablePayload
37+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
3538
from sagemaker.model_monitor import (
3639
DataCaptureConfig,
3740
DefaultModelMonitor,
@@ -201,20 +204,44 @@ def _create_request_args(
201204
custom_attributes=None,
202205
):
203206
"""Placeholder docstring"""
207+
208+
jumpstart_serialized_data: Optional[Union[str, bytes]] = None
209+
jumpstart_accept: Optional[str] = None
210+
jumpstart_content_type: Optional[str] = None
211+
212+
if isinstance(data, JumpStartSerializablePayload):
213+
s3_client = self.sagemaker_session.s3_client
214+
region = self.sagemaker_session._region_name
215+
bucket = get_jumpstart_content_bucket(region)
216+
217+
jumpstart_serialized_data = PayloadSerializer(
218+
bucket=bucket, region=region, s3_client=s3_client
219+
).serialize(data)
220+
jumpstart_content_type = data.content_type
221+
jumpstart_accept = data.accept
222+
204223
args = dict(initial_args) if initial_args else {}
205224

206225
if "EndpointName" not in args:
207226
args["EndpointName"] = self.endpoint_name
208227

209228
if "ContentType" not in args:
210-
args["ContentType"] = (
211-
self.content_type
212-
if isinstance(self.content_type, str)
213-
else ", ".join(self.content_type)
214-
)
229+
if isinstance(data, JumpStartSerializablePayload) and jumpstart_content_type:
230+
args["ContentType"] = jumpstart_content_type
231+
else:
232+
args["ContentType"] = (
233+
self.content_type
234+
if isinstance(self.content_type, str)
235+
else ", ".join(self.content_type)
236+
)
215237

216238
if "Accept" not in args:
217-
args["Accept"] = self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
239+
if isinstance(data, JumpStartSerializablePayload) and jumpstart_accept:
240+
args["Accept"] = jumpstart_accept
241+
else:
242+
args["Accept"] = (
243+
self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
244+
)
218245

219246
if target_model:
220247
args["TargetModel"] = target_model
@@ -228,7 +255,11 @@ def _create_request_args(
228255
if custom_attributes:
229256
args["CustomAttributes"] = custom_attributes
230257

231-
data = self.serializer.serialize(data)
258+
data = (
259+
jumpstart_serialized_data
260+
if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data
261+
else self.serializer.serialize(data)
262+
)
232263

233264
args["Body"] = data
234265
return args

src/sagemaker/jumpstart/accessors.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains accessors related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
import functools
1516
from typing import Any, Dict, List, Optional
1617
import boto3
1718

@@ -37,6 +38,88 @@ def get_sagemaker_version() -> str:
3738
return SageMakerSettings._parsed_sagemaker_version
3839

3940

41+
class JumpStartS3PayloadAccessor(object):
42+
"""Static class for storing and retrieving S3 payload artifacts."""
43+
44+
MAX_CACHE_SIZE_BYTES = int(100 * 1e6)
45+
MAX_PAYLOAD_SIZE_BYTES = int(6 * 1e6)
46+
47+
CACHE_SIZE = MAX_CACHE_SIZE_BYTES // MAX_PAYLOAD_SIZE_BYTES
48+
49+
@staticmethod
50+
def clear_cache() -> None:
51+
"""Clears LRU caches associated with S3 client and retrieved objects."""
52+
53+
JumpStartS3PayloadAccessor._get_default_s3_client.cache_clear()
54+
JumpStartS3PayloadAccessor.get_object_cached.cache_clear()
55+
56+
@staticmethod
57+
@functools.lru_cache()
58+
def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client:
59+
"""Returns default S3 client associated with the region.
60+
61+
Result is cached so multiple clients in memory are not created.
62+
"""
63+
return boto3.client("s3", region_name=region)
64+
65+
@staticmethod
66+
@functools.lru_cache(maxsize=CACHE_SIZE)
67+
def get_object_cached(
68+
bucket: str,
69+
key: str,
70+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
71+
s3_client: Optional[boto3.client] = None,
72+
) -> bytes:
73+
"""Returns S3 object located at the bucket and key.
74+
75+
Requests are cached so that the same S3 request is never made more
76+
than once, unless a different region or client is used.
77+
"""
78+
return JumpStartS3PayloadAccessor.get_object(
79+
bucket=bucket, key=key, region=region, s3_client=s3_client
80+
)
81+
82+
@staticmethod
83+
def _get_object_size_bytes(
84+
bucket: str,
85+
key: str,
86+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
87+
s3_client: Optional[boto3.client] = None,
88+
) -> bytes:
89+
"""Returns size in bytes of S3 object using S3.HeadObject operation."""
90+
if s3_client is None:
91+
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region)
92+
93+
return s3_client.head_object(Bucket=bucket, Key=key)["ContentLength"]
94+
95+
@staticmethod
96+
def get_object(
97+
bucket: str,
98+
key: str,
99+
region: str = JUMPSTART_DEFAULT_REGION_NAME,
100+
s3_client: Optional[boto3.client] = None,
101+
) -> bytes:
102+
"""Returns S3 object located at the bucket and key.
103+
104+
Raises:
105+
ValueError: The object size is too large.
106+
"""
107+
if s3_client is None:
108+
s3_client = JumpStartS3PayloadAccessor._get_default_s3_client(region)
109+
110+
object_size_bytes = JumpStartS3PayloadAccessor._get_object_size_bytes(
111+
bucket=bucket, key=key, region=region, s3_client=s3_client
112+
)
113+
if object_size_bytes > JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES:
114+
raise ValueError(
115+
f"s3://{bucket}/{key} has size of {object_size_bytes} bytes, "
116+
"which exceeds maximum allowed size of "
117+
f"{JumpStartS3PayloadAccessor.MAX_PAYLOAD_SIZE_BYTES} bytes."
118+
)
119+
120+
return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
121+
122+
40123
class JumpStartModelsAccessor(object):
41124
"""Static class for storing the JumpStart models cache."""
42125

src/sagemaker/jumpstart/artifacts/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@
6161
_retrieve_model_package_arn,
6262
_retrieve_model_package_model_artifact_s3_uri,
6363
)
64+
from sagemaker.jumpstart.artifacts.payloads import ( # noqa: F401
65+
_retrieve_example_payloads,
66+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functions to obtain JumpStart model payloads."""
14+
from __future__ import absolute_import
15+
from copy import deepcopy
16+
from typing import Dict, Optional
17+
from sagemaker.jumpstart.constants import (
18+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
19+
JUMPSTART_DEFAULT_REGION_NAME,
20+
)
21+
from sagemaker.jumpstart.enums import (
22+
JumpStartScriptScope,
23+
)
24+
from sagemaker.jumpstart.types import JumpStartSerializablePayload
25+
from sagemaker.jumpstart.utils import (
26+
verify_model_region_and_return_specs,
27+
)
28+
from sagemaker.session import Session
29+
30+
31+
def _retrieve_example_payloads(
32+
model_id: str,
33+
model_version: str,
34+
region: Optional[str],
35+
tolerate_vulnerable_model: bool = False,
36+
tolerate_deprecated_model: bool = False,
37+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
38+
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
39+
"""Returns example payloads.
40+
41+
Args:
42+
model_id (str): JumpStart model ID of the JumpStart model for which to
43+
get example payloads.
44+
model_version (str): Version of the JumpStart model for which to retrieve the
45+
example payloads.
46+
region (Optional[str]): Region for which to retrieve the
47+
example payloads.
48+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
49+
specifications should be tolerated (exception not raised). If False, raises an
50+
exception if the script used by this version of the model has dependencies with known
51+
security vulnerabilities. (Default: False).
52+
tolerate_deprecated_model (bool): True if deprecated versions of model
53+
specifications should be tolerated (exception not raised). If False, raises
54+
an exception if the version of the model is deprecated. (Default: False).
55+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
56+
object, used for SageMaker interactions. If not
57+
specified, one is created using the default AWS configuration
58+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
59+
Returns:
60+
Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases
61+
to the serializable payload object.
62+
"""
63+
64+
if region is None:
65+
region = JUMPSTART_DEFAULT_REGION_NAME
66+
67+
model_specs = verify_model_region_and_return_specs(
68+
model_id=model_id,
69+
version=model_version,
70+
scope=JumpStartScriptScope.INFERENCE,
71+
region=region,
72+
tolerate_vulnerable_model=tolerate_vulnerable_model,
73+
tolerate_deprecated_model=tolerate_deprecated_model,
74+
sagemaker_session=sagemaker_session,
75+
)
76+
77+
default_payloads = model_specs.default_payloads
78+
79+
if default_payloads:
80+
for payload in default_payloads.values():
81+
payload.accept = getattr(
82+
payload, "accept", model_specs.predictor_specs.default_accept_type
83+
)
84+
85+
return deepcopy(default_payloads) if default_payloads else None

src/sagemaker/jumpstart/model.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717

1818
from typing import Dict, List, Optional, Union
19+
from sagemaker import payloads
1920
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2021
from sagemaker.base_deserializers import BaseDeserializer
2122
from sagemaker.base_serializers import BaseSerializer
@@ -28,6 +29,7 @@
2829
get_deploy_kwargs,
2930
get_init_kwargs,
3031
)
32+
from sagemaker.jumpstart.types import JumpStartSerializablePayload
3133
from sagemaker.jumpstart.utils import is_valid_model_id
3234
from sagemaker.utils import stringify_object
3335
from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model
@@ -312,6 +314,46 @@ def _is_valid_model_id_hook():
312314

313315
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
314316

317+
def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
318+
"""Returns all example payloads associated with the model.
319+
320+
Raises:
321+
NotImplementedError: If the scope is not supported.
322+
ValueError: If the combination of arguments specified is not supported.
323+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
324+
known security vulnerabilities.
325+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
326+
"""
327+
return payloads.retrieve_all_examples(
328+
model_id=self.model_id,
329+
model_version=self.model_version,
330+
region=self.region,
331+
tolerate_deprecated_model=self.tolerate_deprecated_model,
332+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
333+
sagemaker_session=self.sagemaker_session,
334+
)
335+
336+
def retrieve_example_payload(self) -> JumpStartSerializablePayload:
337+
"""Returns the example payload associated with the model.
338+
339+
Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function.
340+
341+
Raises:
342+
NotImplementedError: If the scope is not supported.
343+
ValueError: If the combination of arguments specified is not supported.
344+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
345+
known security vulnerabilities.
346+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
347+
"""
348+
return payloads.retrieve_example(
349+
model_id=self.model_id,
350+
model_version=self.model_version,
351+
region=self.region,
352+
tolerate_deprecated_model=self.tolerate_deprecated_model,
353+
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
354+
sagemaker_session=self.sagemaker_session,
355+
)
356+
315357
def _create_sagemaker_model(
316358
self,
317359
instance_type=None,

0 commit comments

Comments
 (0)