Skip to content

feat: jumpstart contruct payload utility #4190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions src/sagemaker/jumpstart/payload_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,87 @@
from __future__ import absolute_import
import base64
import json
from typing import Optional, Union
from typing import Dict, Optional, Union
import re
import boto3

from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
JUMPSTART_DEFAULT_REGION_NAME,
)
from sagemaker.jumpstart.enums import MIMEType
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
from sagemaker.session import Session

S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"


def _construct_payload(
prompt: str,
model_id: str,
model_version: str,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> Optional[JumpStartSerializablePayload]:
"""Returns example payload from prompt.

Args:
prompt (str): String-valued prompt to embed in payload.
model_id (str): JumpStart model ID of the JumpStart model for which to construct
the payload.
model_version (str): Version of the JumpStart model for which to retrieve the
payload.
region (Optional[str]): Region for which to retrieve the
payload. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
security vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model
specifications should be tolerated (exception not raised). If False, raises
an exception if the version of the model is deprecated. (Default: False).
sagemaker_session (sagemaker.session.Session): A SageMaker Session
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
Returns:
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
this feature is unavailable for the specified model.
"""
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
model_id=model_id,
model_version=model_version,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)
if payloads is None or len(payloads) == 0:
return None

payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0]

prompt_key: Optional[str] = payload_to_use.prompt_key
if prompt_key is None:
return None

payload_body = payload_to_use.body
prompt_key_split = prompt_key.split(".")
for idx, prompt_key in enumerate(prompt_key_split):
if idx < len(prompt_key_split) - 1:
payload_body = payload_body[prompt_key]
else:
payload_body[prompt_key] = prompt

return payload_to_use


class PayloadSerializer:
"""Utility class for serializing payloads associated with JumpStart models.

Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,10 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
"content_type",
"accept",
"body",
"prompt_key",
]

_non_serializable_slots = ["raw_payload"]
_non_serializable_slots = ["raw_payload", "prompt_key"]

def __init__(self, spec: Optional[Dict[str, Any]]):
"""Initializes a JumpStartSerializablePayload object from its json representation.
Expand Down Expand Up @@ -364,6 +365,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
self.content_type = json_obj["content_type"]
self.body = json_obj["body"]
accept = json_obj.get("accept")
self.prompt_key = json_obj.get("prompt_key")
if accept:
self.accept = accept

Expand Down
162 changes: 162 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2220,6 +2220,168 @@
},
},
},
"prompt-key": {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this gave me a pause, could you rename prompt-key-model ?

"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
"version": "1.0.0",
"min_sdk_version": "2.144.0",
"training_supported": False,
"incremental_training_supported": False,
"hosting_ecr_specs": {
"framework": "djl-deepspeed",
"framework_version": "0.21.0",
"py_version": "py38",
"huggingface_transformers_version": "4.17",
},
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
"hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz",
"hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/"
"infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
"hosting_prepacked_artifact_version": "1.0.0",
"inference_vulnerable": False,
"inference_dependencies": [
"accelerate==0.18.0",
"diffusers==0.14.0",
"fsspec==2023.4.0",
"huggingface-hub==0.14.1",
"transformers==4.26.1",
],
"inference_vulnerabilities": [],
"training_vulnerable": False,
"training_dependencies": [],
"training_vulnerabilities": [],
"deprecated": False,
"inference_environment_variables": [
{
"name": "SAGEMAKER_PROGRAM",
"type": "text",
"default": "inference.py",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
"type": "text",
"default": "/opt/ml/model/code",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
"type": "text",
"default": "20",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
"type": "text",
"default": "3600",
"scope": "container",
"required_for_model_class": False,
},
{
"name": "ENDPOINT_SERVER_TIMEOUT",
"type": "int",
"default": 3600,
"scope": "container",
"required_for_model_class": True,
},
{
"name": "MODEL_CACHE_ROOT",
"type": "text",
"default": "/opt/ml/model",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_ENV",
"type": "text",
"default": "1",
"scope": "container",
"required_for_model_class": True,
},
{
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
"type": "int",
"default": 1,
"scope": "container",
"required_for_model_class": True,
},
],
"metrics": [],
"default_inference_instance_type": "ml.g5.8xlarge",
"supported_inference_instance_types": [
"ml.g5.8xlarge",
"ml.g5.xlarge",
"ml.g5.2xlarge",
"ml.g5.4xlarge",
"ml.g5.16xlarge",
"ml.p3.2xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.2xlarge",
"ml.g4dn.4xlarge",
"ml.g4dn.8xlarge",
"ml.g4dn.16xlarge",
],
"model_kwargs": {},
"deploy_kwargs": {},
"predictor_specs": {
"supported_content_types": ["application/json"],
"supported_accept_types": ["application/json"],
"default_content_type": "application/json",
"default_accept_type": "application/json",
},
"inference_enable_network_isolation": True,
"validation_supported": False,
"fine_tuning_supported": False,
"resource_name_base": "sd-1-5-controlnet-1-1-fp16",
"default_payloads": {
"Dog": {
"content_type": "application/json",
"prompt_key": "hello.prompt",
"body": {
"hello": {"prompt": "a dog"},
"seed": 43,
},
}
},
"hosting_instance_type_variants": {
"regional_aliases": {
"af-south-1": {
"alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d"
"jl-inference:0.21.0-deepspeed0.8.3-cu117"
},
},
"variants": {
"c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
"t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
},
},
},
"predictor-specs-model": {
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
"url": "https://huggingface.co/google/flan-t5-xxl",
Expand Down
37 changes: 36 additions & 1 deletion tests/unit/sagemaker/jumpstart/test_payload_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,43 @@
from unittest import TestCase
from mock.mock import patch

from sagemaker.jumpstart.payload_utils import PayloadSerializer
from sagemaker.jumpstart.payload_utils import PayloadSerializer, _construct_payload
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec


class TestConstructPayload(TestCase):
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_construct_payload(self, patched_get_model_specs):
patched_get_model_specs.side_effect = get_special_model_spec

model_id = "prompt-key"
region = "us-west-2"

constructed_payload_body = _construct_payload(
prompt="kobebryant",
model_id=model_id,
model_version="*",
region=region,
).body

self.assertEqual(
{
"hello": {"prompt": "kobebryant"},
"seed": 43,
},
constructed_payload_body,
)

# Unsupported model
self.assertIsNone(
_construct_payload(
prompt="blah",
model_id="default_payloads",
model_version="*",
region=region,
)
)


class TestPayloadSerializer(TestCase):
Expand Down