Skip to content

Commit 214e458

Browse files
committed
chore: address PR comments
1 parent b2a6374 commit 214e458

File tree

8 files changed

+56
-62
lines changed

8 files changed

+56
-62
lines changed

src/sagemaker/base_predictor.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import print_function, absolute_import
1515

1616
import abc
17-
from typing import Any, Tuple
17+
from typing import Any, Optional, Tuple, Union
1818

1919
from sagemaker.deprecations import (
2020
deprecated_class,
@@ -205,27 +205,29 @@ def _create_request_args(
205205
):
206206
"""Placeholder docstring"""
207207

208-
js_accept = None
208+
jumpstart_serialized_data: Optional[Union[str, bytes]] = None
209+
jumpstart_accept: Optional[str] = None
210+
jumpstart_content_type: Optional[str] = None
209211

210212
if isinstance(data, JumpStartSerializablePayload):
211213
s3_client = self.sagemaker_session.s3_client
212214
region = self.sagemaker_session._region_name
213215
bucket = get_jumpstart_content_bucket(region)
214216

215-
js_serialized_data = PayloadSerializer(
217+
jumpstart_serialized_data = PayloadSerializer(
216218
bucket=bucket, region=region, s3_client=s3_client
217219
).serialize(data)
218-
js_content_type = data.content_type
219-
js_accept = data.accept
220+
jumpstart_content_type = data.content_type
221+
jumpstart_accept = data.accept
220222

221223
args = dict(initial_args) if initial_args else {}
222224

223225
if "EndpointName" not in args:
224226
args["EndpointName"] = self.endpoint_name
225227

226228
if "ContentType" not in args:
227-
if isinstance(data, JumpStartSerializablePayload):
228-
args["ContentType"] = js_content_type
229+
if isinstance(data, JumpStartSerializablePayload) and jumpstart_content_type:
230+
args["ContentType"] = jumpstart_content_type
229231
else:
230232
args["ContentType"] = (
231233
self.content_type
@@ -234,8 +236,8 @@ def _create_request_args(
234236
)
235237

236238
if "Accept" not in args:
237-
if isinstance(data, JumpStartSerializablePayload) and js_accept:
238-
args["Accept"] = js_accept
239+
if isinstance(data, JumpStartSerializablePayload) and jumpstart_accept:
240+
args["Accept"] = jumpstart_accept
239241
else:
240242
args["Accept"] = (
241243
self.accept if isinstance(self.accept, str) else ", ".join(self.accept)
@@ -254,9 +256,9 @@ def _create_request_args(
254256
args["CustomAttributes"] = custom_attributes
255257

256258
data = (
257-
self.serializer.serialize(data)
258-
if not isinstance(data, JumpStartSerializablePayload)
259-
else js_serialized_data
259+
jumpstart_serialized_data
260+
if isinstance(data, JumpStartSerializablePayload) and jumpstart_serialized_data
261+
else self.serializer.serialize(data)
260262
)
261263

262264
args["Body"] = data

src/sagemaker/jumpstart/accessors.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,19 @@ def get_sagemaker_version() -> str:
3939

4040

4141
class JumpStartS3Accessor(object):
42-
"""Static class for storing and retrieving auxilliary s3 artifacts."""
42+
"""Static class for storing and retrieving auxiliary S3 artifacts."""
43+
44+
@staticmethod
45+
def clear_cache() -> None:
46+
"""Clears LRU caches associated with S3 client and retrieved objects."""
47+
48+
JumpStartS3Accessor._get_default_s3_client.cache_clear()
49+
JumpStartS3Accessor.get_object_cached.cache_clear()
4350

4451
@staticmethod
4552
@functools.lru_cache()
4653
def _get_default_s3_client(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> boto3.client:
47-
"""Returns default s3 client associated with the region.
54+
"""Returns default S3 client associated with the region.
4855
4956
Result is cached so multiple clients in memory are not created.
5057
"""
@@ -58,9 +65,9 @@ def get_object_cached(
5865
region: str = JUMPSTART_DEFAULT_REGION_NAME,
5966
s3_client: Optional[boto3.client] = None,
6067
) -> bytes:
61-
"""Returns s3 object located at the bucket and key.
68+
"""Returns S3 object located at the bucket and key.
6269
63-
Requests are cached so that the same s3 request is never made more
70+
Requests are cached so that the same S3 request is never made more
6471
than once, unless a different region or client is used.
6572
"""
6673
return JumpStartS3Accessor.get_object(
@@ -74,7 +81,7 @@ def get_object(
7481
region: str = JUMPSTART_DEFAULT_REGION_NAME,
7582
s3_client: Optional[boto3.client] = None,
7683
) -> bytes:
77-
"""Returns s3 object located at the bucket and key."""
84+
"""Returns S3 object located at the bucket and key."""
7885
if s3_client is None:
7986
s3_client = JumpStartS3Accessor._get_default_s3_client(region)
8087

src/sagemaker/jumpstart/artifacts/payloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This module contains functions for obtaining JumpStart payloads."""
13+
"""This module contains functions for obtaining example payloads for JumpStart models."""
1414
from __future__ import absolute_import
1515
from copy import deepcopy
1616
from typing import Dict, Optional

src/sagemaker/jumpstart/model.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,11 @@ def _is_valid_model_id_hook():
315315
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
316316

317317
def retrieve_default_payload(self) -> JumpStartSerializablePayload:
318-
"""Returns default payload associated with the model.
318+
"""Returns the default payload associated with the model.
319319
320320
Payload can be directly used with the `sagemaker.predictor.Predictor.predict(...)` function.
321321
"""
322-
sample_payloads: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_samples(
322+
payload_options: Optional[List[JumpStartSerializablePayload]] = payloads.retrieve_options(
323323
model_id=self.model_id,
324324
model_version=self.model_version,
325325
region=self.region,
@@ -328,12 +328,10 @@ def retrieve_default_payload(self) -> JumpStartSerializablePayload:
328328
sagemaker_session=self.sagemaker_session,
329329
)
330330

331-
if sample_payloads is None or len(sample_payloads) == 0:
332-
raise NotImplementedError(
333-
f"No default payload supported for model ID '{self.model_id}'."
334-
)
331+
if payload_options is None or len(payload_options) == 0:
332+
return None
335333

336-
return sample_payloads[0]
334+
return payload_options[0]
337335

338336
def _create_sagemaker_model(
339337
self,

src/sagemaker/jumpstart/payload_utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This module stores payload utilities for SageMaker JumpStart."""
13+
"""This module stores inference payload utilities for JumpStart models."""
1414
from __future__ import absolute_import
1515
import base64
1616
import json
17-
from typing import Any, Optional, Union
17+
from typing import Optional, Union
1818
import re
1919
import boto3
2020

@@ -50,14 +50,14 @@ def get_bytes_payload_with_s3_references(
5050
self,
5151
payload_str: str,
5252
) -> bytes:
53-
"""Returns bytes object corresponding to referenced s3 object.
53+
"""Returns bytes object corresponding to referenced S3 object.
5454
5555
Raises:
5656
ValueError: If the raw bytes payload is not formatted correctly.
5757
"""
5858
s3_keys = re.compile(S3_BYTES_REGEX).findall(payload_str)
5959
if len(s3_keys) != 1:
60-
raise ValueError(f"Invalid bytes payload: {payload_str}")
60+
raise ValueError("Invalid bytes payload.")
6161

6262
s3_key = s3_keys[0]
6363
serialized_s3_object = JumpStartS3Accessor.get_object_cached(
@@ -70,7 +70,10 @@ def embed_s3_references_in_str_payload(
7070
self,
7171
payload: str,
7272
) -> str:
73-
"""Embeds s3 references in string payloads."""
73+
"""Inserts serialized S3 content into string payload.
74+
75+
If no S3 content is embedded in payload, original string is returned.
76+
"""
7477
return self._embed_s3_b64_references_in_str_payload(payload_body=payload)
7578

7679
def _embed_s3_b64_references_in_str_payload(
@@ -98,10 +101,12 @@ def _embed_s3_b64_references_in_str_payload(
98101
def embed_s3_references_in_json_payload(
99102
self, payload_body: Union[list, dict, str, int, float]
100103
) -> Union[list, dict, str, int, float]:
101-
"""Finds all s3 references in payload and embeds serialized s3 data.
104+
"""Finds all S3 references in payload and embeds serialized S3 data.
102105
103-
S3 bucket is assumed to be the default JumpStart content bucket. If no s3 references
104-
are found, the payload is returned un-modified.
106+
If no S3 references are found, the payload is returned un-modified.
107+
108+
Raises:
109+
ValueError: If the payload has an unrecognized type.
105110
"""
106111
if isinstance(payload_body, str):
107112
return self.embed_s3_references_in_str_payload(payload_body)
@@ -116,8 +121,12 @@ def embed_s3_references_in_json_payload(
116121
}
117122
raise ValueError(f"Payload has unrecognized type: {type(payload_body)}")
118123

119-
def serialize(self, payload: JumpStartSerializablePayload) -> Any:
120-
"""Returns payload bytes that can be inputted to inference endpoint."""
124+
def serialize(self, payload: JumpStartSerializablePayload) -> Union[str, bytes]:
125+
"""Returns payload string or bytes that can be sent to an inference endpoint.
126+
127+
Raises:
128+
ValueError: If the payload has an unrecognized type.
129+
"""
121130
content_type = MIMEType.from_suffixed_type(payload.content_type)
122131
body = payload.body
123132

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
348348
Args:
349349
json_obj (Dict[str, Any]): Dictionary representation of serializable
350350
payload specs.
351+
352+
Raises:
353+
KeyError: If the dictionary is missing keys.
351354
"""
352355

353356
if json_obj is None:

src/sagemaker/payloads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
logger = logging.getLogger(__name__)
2828

2929

30-
def retrieve_samples(
30+
def retrieve_options(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
@@ -78,7 +78,7 @@ def retrieve_samples(
7878
Dict[str, JumpStartSerializablePayload]
7979
] = artifacts._retrieve_default_payloads(
8080
model_id,
81-
model_version, # type: ignore
81+
model_version,
8282
region,
8383
tolerate_vulnerable_model,
8484
tolerate_deprecated_model,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,31 +86,6 @@ def test_prepacked_jumpstart_model(setup):
8686
assert response is not None
8787

8888

89-
def test_default_payload_jumpstart_model(setup):
90-
91-
# DO NOT COMMIT THIS LINE
92-
os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"})
93-
94-
model_id = "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16"
95-
96-
model = JumpStartModel(
97-
model_id=model_id,
98-
role=get_sm_session().get_caller_identity_arn(),
99-
sagemaker_session=get_sm_session(),
100-
)
101-
102-
default_payload = model.retrieve_default_payload()
103-
104-
# uses ml.g5.8xlarge instance
105-
predictor = model.deploy(
106-
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
107-
)
108-
109-
response = predictor.predict(default_payload)
110-
111-
assert response is not None
112-
113-
11489
@pytest.mark.skipif(
11590
tests.integ.test_region() not in GATED_INFERENCE_MODEL_SUPPORTED_REGIONS,
11691
reason=f"JumpStart gated inference models unavailable in {tests.integ.test_region()}.",

0 commit comments

Comments
 (0)