Skip to content

Commit 80df01c

Browse files
authored
feat: jumpstart contruct payload utility (#4190)
1 parent fcfc402 commit 80df01c

File tree

4 files changed

+270
-4
lines changed

4 files changed

+270
-4
lines changed

src/sagemaker/jumpstart/payload_utils.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,87 @@
1414
from __future__ import absolute_import
1515
import base64
1616
import json
17-
from typing import Optional, Union
17+
from typing import Dict, Optional, Union
1818
import re
1919
import boto3
2020

2121
from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
22-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
22+
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
23+
from sagemaker.jumpstart.constants import (
24+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
25+
JUMPSTART_DEFAULT_REGION_NAME,
26+
)
2327
from sagemaker.jumpstart.enums import MIMEType
2428
from sagemaker.jumpstart.types import JumpStartSerializablePayload
2529
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
30+
from sagemaker.session import Session
2631

2732
S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
2833
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"
2934

3035

36+
def _construct_payload(
37+
prompt: str,
38+
model_id: str,
39+
model_version: str,
40+
region: Optional[str] = None,
41+
tolerate_vulnerable_model: bool = False,
42+
tolerate_deprecated_model: bool = False,
43+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
44+
) -> Optional[JumpStartSerializablePayload]:
45+
"""Returns example payload from prompt.
46+
47+
Args:
48+
prompt (str): String-valued prompt to embed in payload.
49+
model_id (str): JumpStart model ID of the JumpStart model for which to construct
50+
the payload.
51+
model_version (str): Version of the JumpStart model for which to retrieve the
52+
payload.
53+
region (Optional[str]): Region for which to retrieve the
54+
payload. (Default: None).
55+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
56+
specifications should be tolerated (exception not raised). If False, raises an
57+
exception if the script used by this version of the model has dependencies with known
58+
security vulnerabilities. (Default: False).
59+
tolerate_deprecated_model (bool): True if deprecated versions of model
60+
specifications should be tolerated (exception not raised). If False, raises
61+
an exception if the version of the model is deprecated. (Default: False).
62+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
63+
object, used for SageMaker interactions. If not
64+
specified, one is created using the default AWS configuration
65+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
66+
Returns:
67+
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
68+
this feature is unavailable for the specified model.
69+
"""
70+
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
71+
model_id=model_id,
72+
model_version=model_version,
73+
region=region,
74+
tolerate_vulnerable_model=tolerate_vulnerable_model,
75+
tolerate_deprecated_model=tolerate_deprecated_model,
76+
sagemaker_session=sagemaker_session,
77+
)
78+
if payloads is None or len(payloads) == 0:
79+
return None
80+
81+
payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0]
82+
83+
prompt_key: Optional[str] = payload_to_use.prompt_key
84+
if prompt_key is None:
85+
return None
86+
87+
payload_body = payload_to_use.body
88+
prompt_key_split = prompt_key.split(".")
89+
for idx, prompt_key in enumerate(prompt_key_split):
90+
if idx < len(prompt_key_split) - 1:
91+
payload_body = payload_body[prompt_key]
92+
else:
93+
payload_body[prompt_key] = prompt
94+
95+
return payload_to_use
96+
97+
3198
class PayloadSerializer:
3299
"""Utility class for serializing payloads associated with JumpStart models.
33100

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,10 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
334334
"content_type",
335335
"accept",
336336
"body",
337+
"prompt_key",
337338
]
338339

339-
_non_serializable_slots = ["raw_payload"]
340+
_non_serializable_slots = ["raw_payload", "prompt_key"]
340341

341342
def __init__(self, spec: Optional[Dict[str, Any]]):
342343
"""Initializes a JumpStartSerializablePayload object from its json representation.
@@ -364,6 +365,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
364365
self.content_type = json_obj["content_type"]
365366
self.body = json_obj["body"]
366367
accept = json_obj.get("accept")
368+
self.prompt_key = json_obj.get("prompt_key")
367369
if accept:
368370
self.accept = accept
369371

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,6 +2220,168 @@
22202220
},
22212221
},
22222222
},
2223+
"prompt-key": {
2224+
"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
2225+
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
2226+
"version": "1.0.0",
2227+
"min_sdk_version": "2.144.0",
2228+
"training_supported": False,
2229+
"incremental_training_supported": False,
2230+
"hosting_ecr_specs": {
2231+
"framework": "djl-deepspeed",
2232+
"framework_version": "0.21.0",
2233+
"py_version": "py38",
2234+
"huggingface_transformers_version": "4.17",
2235+
},
2236+
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
2237+
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
2238+
"hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz",
2239+
"hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/"
2240+
"infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
2241+
"hosting_prepacked_artifact_version": "1.0.0",
2242+
"inference_vulnerable": False,
2243+
"inference_dependencies": [
2244+
"accelerate==0.18.0",
2245+
"diffusers==0.14.0",
2246+
"fsspec==2023.4.0",
2247+
"huggingface-hub==0.14.1",
2248+
"transformers==4.26.1",
2249+
],
2250+
"inference_vulnerabilities": [],
2251+
"training_vulnerable": False,
2252+
"training_dependencies": [],
2253+
"training_vulnerabilities": [],
2254+
"deprecated": False,
2255+
"inference_environment_variables": [
2256+
{
2257+
"name": "SAGEMAKER_PROGRAM",
2258+
"type": "text",
2259+
"default": "inference.py",
2260+
"scope": "container",
2261+
"required_for_model_class": True,
2262+
},
2263+
{
2264+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
2265+
"type": "text",
2266+
"default": "/opt/ml/model/code",
2267+
"scope": "container",
2268+
"required_for_model_class": False,
2269+
},
2270+
{
2271+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
2272+
"type": "text",
2273+
"default": "20",
2274+
"scope": "container",
2275+
"required_for_model_class": False,
2276+
},
2277+
{
2278+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
2279+
"type": "text",
2280+
"default": "3600",
2281+
"scope": "container",
2282+
"required_for_model_class": False,
2283+
},
2284+
{
2285+
"name": "ENDPOINT_SERVER_TIMEOUT",
2286+
"type": "int",
2287+
"default": 3600,
2288+
"scope": "container",
2289+
"required_for_model_class": True,
2290+
},
2291+
{
2292+
"name": "MODEL_CACHE_ROOT",
2293+
"type": "text",
2294+
"default": "/opt/ml/model",
2295+
"scope": "container",
2296+
"required_for_model_class": True,
2297+
},
2298+
{
2299+
"name": "SAGEMAKER_ENV",
2300+
"type": "text",
2301+
"default": "1",
2302+
"scope": "container",
2303+
"required_for_model_class": True,
2304+
},
2305+
{
2306+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
2307+
"type": "int",
2308+
"default": 1,
2309+
"scope": "container",
2310+
"required_for_model_class": True,
2311+
},
2312+
],
2313+
"metrics": [],
2314+
"default_inference_instance_type": "ml.g5.8xlarge",
2315+
"supported_inference_instance_types": [
2316+
"ml.g5.8xlarge",
2317+
"ml.g5.xlarge",
2318+
"ml.g5.2xlarge",
2319+
"ml.g5.4xlarge",
2320+
"ml.g5.16xlarge",
2321+
"ml.p3.2xlarge",
2322+
"ml.g4dn.xlarge",
2323+
"ml.g4dn.2xlarge",
2324+
"ml.g4dn.4xlarge",
2325+
"ml.g4dn.8xlarge",
2326+
"ml.g4dn.16xlarge",
2327+
],
2328+
"model_kwargs": {},
2329+
"deploy_kwargs": {},
2330+
"predictor_specs": {
2331+
"supported_content_types": ["application/json"],
2332+
"supported_accept_types": ["application/json"],
2333+
"default_content_type": "application/json",
2334+
"default_accept_type": "application/json",
2335+
},
2336+
"inference_enable_network_isolation": True,
2337+
"validation_supported": False,
2338+
"fine_tuning_supported": False,
2339+
"resource_name_base": "sd-1-5-controlnet-1-1-fp16",
2340+
"default_payloads": {
2341+
"Dog": {
2342+
"content_type": "application/json",
2343+
"prompt_key": "hello.prompt",
2344+
"body": {
2345+
"hello": {"prompt": "a dog"},
2346+
"seed": 43,
2347+
},
2348+
}
2349+
},
2350+
"hosting_instance_type_variants": {
2351+
"regional_aliases": {
2352+
"af-south-1": {
2353+
"alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d"
2354+
"jl-inference:0.21.0-deepspeed0.8.3-cu117"
2355+
},
2356+
},
2357+
"variants": {
2358+
"c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2359+
"c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2360+
"c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2361+
"c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2362+
"c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2363+
"g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2364+
"g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2365+
"inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2366+
"inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2367+
"local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2368+
"local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2369+
"m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2370+
"m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2371+
"m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2372+
"p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2373+
"p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2374+
"p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2375+
"p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2376+
"p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2377+
"p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2378+
"r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2379+
"r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2380+
"t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2381+
"t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2382+
},
2383+
},
2384+
},
22232385
"predictor-specs-model": {
22242386
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
22252387
"url": "https://huggingface.co/google/flan-t5-xxl",

tests/unit/sagemaker/jumpstart/test_payload_utils.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,43 @@
1515
from unittest import TestCase
1616
from mock.mock import patch
1717

18-
from sagemaker.jumpstart.payload_utils import PayloadSerializer
18+
from sagemaker.jumpstart.payload_utils import PayloadSerializer, _construct_payload
1919
from sagemaker.jumpstart.types import JumpStartSerializablePayload
20+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
21+
22+
23+
class TestConstructPayload(TestCase):
24+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
25+
def test_construct_payload(self, patched_get_model_specs):
26+
patched_get_model_specs.side_effect = get_special_model_spec
27+
28+
model_id = "prompt-key"
29+
region = "us-west-2"
30+
31+
constructed_payload_body = _construct_payload(
32+
prompt="kobebryant",
33+
model_id=model_id,
34+
model_version="*",
35+
region=region,
36+
).body
37+
38+
self.assertEqual(
39+
{
40+
"hello": {"prompt": "kobebryant"},
41+
"seed": 43,
42+
},
43+
constructed_payload_body,
44+
)
45+
46+
# Unsupported model
47+
self.assertIsNone(
48+
_construct_payload(
49+
prompt="blah",
50+
model_id="default_payloads",
51+
model_version="*",
52+
region=region,
53+
)
54+
)
2055

2156

2257
class TestPayloadSerializer(TestCase):

0 commit comments

Comments
 (0)