Skip to content

Commit 20b5b66

Browse files
committed
Merge remote-tracking branch 'origin' into feat/jumpstart-extract-generated-text-from-response
2 parents 0be4fd5 + b8e3e05 commit 20b5b66

File tree

7 files changed

+286
-13
lines changed

7 files changed

+286
-13
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Changelog
22

3+
## v2.194.0 (2023-10-19)
4+
5+
### Features
6+
7+
* Added register step in Jumpstart model
8+
* jumpstart instance specific metric definitions
9+
10+
### Bug Fixes and Other Changes
11+
12+
* Updates for DJL 0.24.0 Release
13+
* use getter for resource-metadata dict
14+
* add method to Model class to check if repack is needed
15+
316
## v2.193.0 (2023-10-18)
417

518
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.193.1.dev0
1+
2.194.1.dev0

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ def schedule(
281281
Args:
282282
pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
283283
schedule_expression (str): The expression that defines when the schedule runs. It supports
284-
at expression, rate expression and cron expression. See '''https://docs.aws.amazon.com\
285-
/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-\
286-
request-ScheduleExpression''' for more details.
284+
at expression, rate expression and cron expression. See the
285+
`CreateSchedule API
286+
<https://docs.aws.amazon.com/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedulerequest-ScheduleExpression>`_
287+
for more details.
287288
state (str): Specifies whether the schedule is enabled or disabled. Valid values are
288-
ENABLED and DISABLED. See '''https://docs.aws.amazon.com/scheduler/latest/APIReference\
289-
/API_CreateSchedule.html#scheduler-CreateSchedule-request-State'''
289+
ENABLED and DISABLED. See the `State request parameter
290+
<https://docs.aws.amazon.com/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request-State>`_
290291
for more details. If not specified, it will default to ENABLED.
291292
start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin
292293
invoking its target. Depending on the schedule’s recurrence expression, invocations

src/sagemaker/jumpstart/payload_utils.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,18 @@
1919
import boto3
2020

2121
from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
22-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
22+
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
23+
from sagemaker.jumpstart.constants import (
24+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
25+
JUMPSTART_DEFAULT_REGION_NAME,
26+
)
2327
from sagemaker.jumpstart.enums import MIMEType
2428
from sagemaker.jumpstart.types import JumpStartSerializablePayload
2529
from sagemaker.jumpstart.utils import (
2630
get_jumpstart_content_bucket,
2731
)
2832
from sagemaker.session import Session
29-
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
30-
from sagemaker.jumpstart.constants import (
31-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
32-
)
33+
3334

3435
S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
3536
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"
@@ -52,6 +53,67 @@ def _extract_field_from_json(
5253
return curr_json[key]
5354

5455

56+
def _construct_payload(
57+
prompt: str,
58+
model_id: str,
59+
model_version: str,
60+
region: Optional[str] = None,
61+
tolerate_vulnerable_model: bool = False,
62+
tolerate_deprecated_model: bool = False,
63+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
64+
) -> Optional[JumpStartSerializablePayload]:
65+
"""Returns example payload from prompt.
66+
Args:
67+
prompt (str): String-valued prompt to embed in payload.
68+
model_id (str): JumpStart model ID of the JumpStart model for which to construct
69+
the payload.
70+
model_version (str): Version of the JumpStart model for which to retrieve the
71+
payload.
72+
region (Optional[str]): Region for which to retrieve the
73+
payload. (Default: None).
74+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
75+
specifications should be tolerated (exception not raised). If False, raises an
76+
exception if the script used by this version of the model has dependencies with known
77+
security vulnerabilities. (Default: False).
78+
tolerate_deprecated_model (bool): True if deprecated versions of model
79+
specifications should be tolerated (exception not raised). If False, raises
80+
an exception if the version of the model is deprecated. (Default: False).
81+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
82+
object, used for SageMaker interactions. If not
83+
specified, one is created using the default AWS configuration
84+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
85+
Returns:
86+
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
87+
this feature is unavailable for the specified model.
88+
"""
89+
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
90+
model_id=model_id,
91+
model_version=model_version,
92+
region=region,
93+
tolerate_vulnerable_model=tolerate_vulnerable_model,
94+
tolerate_deprecated_model=tolerate_deprecated_model,
95+
sagemaker_session=sagemaker_session,
96+
)
97+
if payloads is None or len(payloads) == 0:
98+
return None
99+
100+
payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0]
101+
102+
prompt_key: Optional[str] = payload_to_use.prompt_key
103+
if prompt_key is None:
104+
return None
105+
106+
payload_body = payload_to_use.body
107+
prompt_key_split = prompt_key.split(".")
108+
for idx, prompt_key in enumerate(prompt_key_split):
109+
if idx < len(prompt_key_split) - 1:
110+
payload_body = payload_body[prompt_key]
111+
else:
112+
payload_body[prompt_key] = prompt
113+
114+
return payload_to_use
115+
116+
55117
def _extract_generated_text_from_response(
56118
response: dict,
57119
model_id: str,

src/sagemaker/jumpstart/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,10 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
335335
"accept",
336336
"body",
337337
"generated_text_response_key",
338+
"prompt_key",
338339
]
339340

340-
_non_serializable_slots = ["raw_payload"]
341+
_non_serializable_slots = ["raw_payload", "prompt_key"]
341342

342343
def __init__(self, spec: Optional[Dict[str, Any]]):
343344
"""Initializes a JumpStartSerializablePayload object from its json representation.
@@ -366,6 +367,7 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
366367
self.body = json_obj["body"]
367368
accept = json_obj.get("accept")
368369
self.generated_text_response_key = json_obj.get("generated_text_response_key")
370+
self.prompt_key = json_obj.get("prompt_key")
369371
if accept:
370372
self.accept = accept
371373

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,6 +2383,168 @@
23832383
},
23842384
},
23852385
},
2386+
"prompt-key": {
2387+
"model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16",
2388+
"url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth",
2389+
"version": "1.0.0",
2390+
"min_sdk_version": "2.144.0",
2391+
"training_supported": False,
2392+
"incremental_training_supported": False,
2393+
"hosting_ecr_specs": {
2394+
"framework": "djl-deepspeed",
2395+
"framework_version": "0.21.0",
2396+
"py_version": "py38",
2397+
"huggingface_transformers_version": "4.17",
2398+
},
2399+
"hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st"
2400+
"able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
2401+
"hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz",
2402+
"hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/"
2403+
"infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz",
2404+
"hosting_prepacked_artifact_version": "1.0.0",
2405+
"inference_vulnerable": False,
2406+
"inference_dependencies": [
2407+
"accelerate==0.18.0",
2408+
"diffusers==0.14.0",
2409+
"fsspec==2023.4.0",
2410+
"huggingface-hub==0.14.1",
2411+
"transformers==4.26.1",
2412+
],
2413+
"inference_vulnerabilities": [],
2414+
"training_vulnerable": False,
2415+
"training_dependencies": [],
2416+
"training_vulnerabilities": [],
2417+
"deprecated": False,
2418+
"inference_environment_variables": [
2419+
{
2420+
"name": "SAGEMAKER_PROGRAM",
2421+
"type": "text",
2422+
"default": "inference.py",
2423+
"scope": "container",
2424+
"required_for_model_class": True,
2425+
},
2426+
{
2427+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
2428+
"type": "text",
2429+
"default": "/opt/ml/model/code",
2430+
"scope": "container",
2431+
"required_for_model_class": False,
2432+
},
2433+
{
2434+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
2435+
"type": "text",
2436+
"default": "20",
2437+
"scope": "container",
2438+
"required_for_model_class": False,
2439+
},
2440+
{
2441+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
2442+
"type": "text",
2443+
"default": "3600",
2444+
"scope": "container",
2445+
"required_for_model_class": False,
2446+
},
2447+
{
2448+
"name": "ENDPOINT_SERVER_TIMEOUT",
2449+
"type": "int",
2450+
"default": 3600,
2451+
"scope": "container",
2452+
"required_for_model_class": True,
2453+
},
2454+
{
2455+
"name": "MODEL_CACHE_ROOT",
2456+
"type": "text",
2457+
"default": "/opt/ml/model",
2458+
"scope": "container",
2459+
"required_for_model_class": True,
2460+
},
2461+
{
2462+
"name": "SAGEMAKER_ENV",
2463+
"type": "text",
2464+
"default": "1",
2465+
"scope": "container",
2466+
"required_for_model_class": True,
2467+
},
2468+
{
2469+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
2470+
"type": "int",
2471+
"default": 1,
2472+
"scope": "container",
2473+
"required_for_model_class": True,
2474+
},
2475+
],
2476+
"metrics": [],
2477+
"default_inference_instance_type": "ml.g5.8xlarge",
2478+
"supported_inference_instance_types": [
2479+
"ml.g5.8xlarge",
2480+
"ml.g5.xlarge",
2481+
"ml.g5.2xlarge",
2482+
"ml.g5.4xlarge",
2483+
"ml.g5.16xlarge",
2484+
"ml.p3.2xlarge",
2485+
"ml.g4dn.xlarge",
2486+
"ml.g4dn.2xlarge",
2487+
"ml.g4dn.4xlarge",
2488+
"ml.g4dn.8xlarge",
2489+
"ml.g4dn.16xlarge",
2490+
],
2491+
"model_kwargs": {},
2492+
"deploy_kwargs": {},
2493+
"predictor_specs": {
2494+
"supported_content_types": ["application/json"],
2495+
"supported_accept_types": ["application/json"],
2496+
"default_content_type": "application/json",
2497+
"default_accept_type": "application/json",
2498+
},
2499+
"inference_enable_network_isolation": True,
2500+
"validation_supported": False,
2501+
"fine_tuning_supported": False,
2502+
"resource_name_base": "sd-1-5-controlnet-1-1-fp16",
2503+
"default_payloads": {
2504+
"Dog": {
2505+
"content_type": "application/json",
2506+
"prompt_key": "hello.prompt",
2507+
"body": {
2508+
"hello": {"prompt": "a dog"},
2509+
"seed": 43,
2510+
},
2511+
}
2512+
},
2513+
"hosting_instance_type_variants": {
2514+
"regional_aliases": {
2515+
"af-south-1": {
2516+
"alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d"
2517+
"jl-inference:0.21.0-deepspeed0.8.3-cu117"
2518+
},
2519+
},
2520+
"variants": {
2521+
"c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2522+
"c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2523+
"c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2524+
"c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2525+
"c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2526+
"g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2527+
"g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2528+
"inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2529+
"inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2530+
"local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2531+
"local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2532+
"m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2533+
"m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2534+
"m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2535+
"p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2536+
"p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2537+
"p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2538+
"p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2539+
"p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2540+
"p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2541+
"r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2542+
"r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2543+
"t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2544+
"t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}},
2545+
},
2546+
},
2547+
},
23862548
"predictor-specs-model": {
23872549
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
23882550
"url": "https://huggingface.co/google/flan-t5-xxl",

tests/unit/sagemaker/jumpstart/test_payload_utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,44 @@
1919
from sagemaker.jumpstart.payload_utils import (
2020
PayloadSerializer,
2121
_extract_generated_text_from_response,
22+
_construct_payload,
2223
)
2324
from sagemaker.jumpstart.types import JumpStartSerializablePayload
25+
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
2426

2527

26-
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
28+
class TestConstructPayload(TestCase):
29+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
30+
def test_construct_payload(self, patched_get_model_specs):
31+
patched_get_model_specs.side_effect = get_special_model_spec
32+
33+
model_id = "prompt-key"
34+
region = "us-west-2"
35+
36+
constructed_payload_body = _construct_payload(
37+
prompt="kobebryant",
38+
model_id=model_id,
39+
model_version="*",
40+
region=region,
41+
).body
42+
43+
self.assertEqual(
44+
{
45+
"hello": {"prompt": "kobebryant"},
46+
"seed": 43,
47+
},
48+
constructed_payload_body,
49+
)
50+
51+
# Unsupported model
52+
self.assertIsNone(
53+
_construct_payload(
54+
prompt="blah",
55+
model_id="default_payloads",
56+
model_version="*",
57+
region=region,
58+
)
59+
)
2760

2861

2962
class TestResponseExtraction(TestCase):

0 commit comments

Comments
 (0)