Skip to content

Commit 42a947d

Browse files
authored
Merge branch 'master' into feat/jumpstart-private-model-artifacts
2 parents 3b8e600 + 4befd93 commit 42a947d

File tree

8 files changed

+627
-12
lines changed

8 files changed

+627
-12
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Changelog
22

3+
## v2.194.0 (2023-10-19)
4+
5+
### Features
6+
7+
* Added register step in Jumpstart model
8+
* jumpstart instance specific metric definitions
9+
10+
### Bug Fixes and Other Changes
11+
12+
* Updates for DJL 0.24.0 Release
13+
* use getter for resource-metadata dict
14+
* add method to Model class to check if repack is needed
15+
316
## v2.193.0 (2023-10-18)
417

518
### Features

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.193.1.dev0
1+
2.194.1.dev0

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,13 @@ def schedule(
281281
Args:
282282
pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
283283
schedule_expression (str): The expression that defines when the schedule runs. It supports
284-
at expression, rate expression and cron expression. See '''https://docs.aws.amazon.com\
285-
/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-\
286-
request-ScheduleExpression''' for more details.
284+
at expression, rate expression and cron expression. See the
285+
`CreateSchedule API
286+
<https://docs.aws.amazon.com/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedulerequest-ScheduleExpression>`_
287+
for more details.
287288
state (str): Specifies whether the schedule is enabled or disabled. Valid values are
288-
ENABLED and DISABLED. See '''https://docs.aws.amazon.com/scheduler/latest/APIReference\
289-
/API_CreateSchedule.html#scheduler-CreateSchedule-request-State'''
289+
ENABLED and DISABLED. See the `State request parameter
290+
<https://docs.aws.amazon.com/scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request-State>`_
290291
for more details. If not specified, it will default to ENABLED.
291292
start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin
292293
invoking its target. Depending on the schedule’s recurrence expression, invocations

src/sagemaker/jumpstart/payload_utils.py

Lines changed: 164 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,181 @@
1414
from __future__ import absolute_import
1515
import base64
1616
import json
17-
from typing import Optional, Union
17+
from typing import Any, Dict, List, Optional, Union
1818
import re
1919
import boto3
2020

2121
from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
22-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
22+
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
23+
from sagemaker.jumpstart.constants import (
24+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
25+
JUMPSTART_DEFAULT_REGION_NAME,
26+
)
2327
from sagemaker.jumpstart.enums import MIMEType
2428
from sagemaker.jumpstart.types import JumpStartSerializablePayload
25-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
29+
from sagemaker.jumpstart.utils import (
30+
get_jumpstart_content_bucket,
31+
)
32+
from sagemaker.session import Session
33+
2634

2735
S3_BYTES_REGEX = r"^\$s3<(?P<s3_key>[a-zA-Z0-9-_/.]+)>$"
2836
S3_B64_STR_REGEX = r"\$s3_b64<(?P<s3_key>[a-zA-Z0-9-_/.]+)>"
2937

3038

39+
def _extract_field_from_json(
40+
json_input: dict,
41+
keys: List[str],
42+
) -> Any:
43+
"""Given a dictionary, returns value at specified keys.
44+
45+
Raises:
46+
KeyError: If a key cannot be found in the json input.
47+
"""
48+
curr_json = json_input
49+
for idx, key in enumerate(keys):
50+
if idx < len(keys) - 1:
51+
curr_json = curr_json[key]
52+
continue
53+
return curr_json[key]
54+
55+
56+
def _construct_payload(
57+
prompt: str,
58+
model_id: str,
59+
model_version: str,
60+
region: Optional[str] = None,
61+
tolerate_vulnerable_model: bool = False,
62+
tolerate_deprecated_model: bool = False,
63+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
64+
) -> Optional[JumpStartSerializablePayload]:
65+
"""Returns example payload from prompt.
66+
67+
Args:
68+
prompt (str): String-valued prompt to embed in payload.
69+
model_id (str): JumpStart model ID of the JumpStart model for which to construct
70+
the payload.
71+
model_version (str): Version of the JumpStart model for which to retrieve the
72+
payload.
73+
region (Optional[str]): Region for which to retrieve the
74+
payload. (Default: None).
75+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
76+
specifications should be tolerated (exception not raised). If False, raises an
77+
exception if the script used by this version of the model has dependencies with known
78+
security vulnerabilities. (Default: False).
79+
tolerate_deprecated_model (bool): True if deprecated versions of model
80+
specifications should be tolerated (exception not raised). If False, raises
81+
an exception if the version of the model is deprecated. (Default: False).
82+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
83+
object, used for SageMaker interactions. If not
84+
specified, one is created using the default AWS configuration
85+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
86+
Returns:
87+
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
88+
this feature is unavailable for the specified model.
89+
"""
90+
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
91+
model_id=model_id,
92+
model_version=model_version,
93+
region=region,
94+
tolerate_vulnerable_model=tolerate_vulnerable_model,
95+
tolerate_deprecated_model=tolerate_deprecated_model,
96+
sagemaker_session=sagemaker_session,
97+
)
98+
if payloads is None or len(payloads) == 0:
99+
return None
100+
101+
payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0]
102+
103+
prompt_key: Optional[str] = payload_to_use.prompt_key
104+
if prompt_key is None:
105+
return None
106+
107+
payload_body = payload_to_use.body
108+
prompt_key_split = prompt_key.split(".")
109+
for idx, prompt_key in enumerate(prompt_key_split):
110+
if idx < len(prompt_key_split) - 1:
111+
payload_body = payload_body[prompt_key]
112+
else:
113+
payload_body[prompt_key] = prompt
114+
115+
return payload_to_use
116+
117+
118+
def _extract_generated_text_from_response(
119+
response: dict,
120+
model_id: str,
121+
model_version: str,
122+
region: Optional[str] = None,
123+
tolerate_vulnerable_model: bool = False,
124+
tolerate_deprecated_model: bool = False,
125+
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126+
accept_type: Optional[str] = None,
127+
) -> str:
128+
"""Returns generated text extracted from full response payload.
129+
130+
Args:
131+
response (dict): Dictionary-valued response from which to extract
132+
generated text.
133+
model_id (str): JumpStart model ID of the JumpStart model from which to extract
134+
generated text.
135+
model_version (str): Version of the JumpStart model for which to extract generated
136+
text.
137+
region (Optional[str]): Region for which to extract generated
138+
text. (Default: None).
139+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
140+
specifications should be tolerated (exception not raised). If False, raises an
141+
exception if the script used by this version of the model has dependencies with known
142+
security vulnerabilities. (Default: False).
143+
tolerate_deprecated_model (bool): True if deprecated versions of model
144+
specifications should be tolerated (exception not raised). If False, raises
145+
an exception if the version of the model is deprecated. (Default: False).
146+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
147+
object, used for SageMaker interactions. If not
148+
specified, one is created using the default AWS configuration
149+
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150+
accept_type (Optional[str]): The accept type to optionally specify for the response.
151+
(Default: None).
152+
153+
Returns:
154+
str: extracted generated text from the endpoint response payload.
155+
156+
Raises:
157+
ValueError: If the model is invalid, the model does not support generated text extraction,
158+
or if the response is malformed.
159+
"""
160+
161+
if not isinstance(response, dict):
162+
raise ValueError(f"Response must be dictionary. Instead, got: {type(response)}")
163+
164+
payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads(
165+
model_id=model_id,
166+
model_version=model_version,
167+
region=region,
168+
tolerate_vulnerable_model=tolerate_vulnerable_model,
169+
tolerate_deprecated_model=tolerate_deprecated_model,
170+
sagemaker_session=sagemaker_session,
171+
)
172+
if payloads is None or len(payloads) == 0:
173+
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
174+
175+
for payload in payloads.values():
176+
if accept_type is None or payload.accept == accept_type:
177+
generated_text_response_key: Optional[str] = payload.generated_text_response_key
178+
if generated_text_response_key is None:
179+
raise ValueError(
180+
f"Model ID '{model_id}' does not support generated text extraction."
181+
)
182+
183+
generated_text_response_key_split = generated_text_response_key.split(".")
184+
try:
185+
return _extract_field_from_json(response, generated_text_response_key_split)
186+
except KeyError:
187+
raise ValueError(f"Response is malformed: {response}")
188+
189+
raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.")
190+
191+
31192
class PayloadSerializer:
32193
"""Utility class for serializing payloads associated with JumpStart models.
33194

src/sagemaker/jumpstart/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,11 @@ class JumpStartSerializablePayload(JumpStartDataHolderType):
339339
"content_type",
340340
"accept",
341341
"body",
342+
"generated_text_response_key",
343+
"prompt_key",
342344
]
343345

344-
_non_serializable_slots = ["raw_payload"]
346+
_non_serializable_slots = ["raw_payload", "prompt_key"]
345347

346348
def __init__(self, spec: Optional[Dict[str, Any]]):
347349
"""Initializes a JumpStartSerializablePayload object from its json representation.
@@ -369,6 +371,8 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None:
369371
self.content_type = json_obj["content_type"]
370372
self.body = json_obj["body"]
371373
accept = json_obj.get("accept")
374+
self.generated_text_response_key = json_obj.get("generated_text_response_key")
375+
self.prompt_key = json_obj.get("prompt_key")
372376
if accept:
373377
self.accept = accept
374378

0 commit comments

Comments
 (0)