Skip to content

Commit 5dc4aac

Browse files
committed
fix: docstring, add integ test
1 parent 9750dc5 commit 5dc4aac

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

src/sagemaker/payloads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def retrieve_samples(
4545
model_version (str): The version of the JumpStart model for which to retrieve
4646
the model payloads.
4747
serialize (bool): Whether to serialize byte-stream valued payloads by downloading
48-
binary files from s3 and applying encoding, or to keep payload in pre-serialized .
49-
state. Set this option to False if you only want to avoid s3 download of it you
48+
binary files from s3 and applying encoding, or to keep payload in pre-serialized
49+
state. Set this option to False if you want to avoid s3 downloads or if you
5050
want to inspect the payload in a human-readable form. (Default: True).
5151
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
5252
specifications should be tolerated without raising an exception. If ``False``, raises an

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,31 @@ def test_prepacked_jumpstart_model(setup):
8686
assert response is not None
8787

8888

89+
def test_default_payload_jumpstart_model(setup):
90+
91+
## DO NOT COMMIT THIS LINE
92+
os.environ.update({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE": "jumpstart-cache-alpha-us-west-2"})
93+
94+
model_id = "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16"
95+
96+
model = JumpStartModel(
97+
model_id=model_id,
98+
role=get_sm_session().get_caller_identity_arn(),
99+
sagemaker_session=get_sm_session(),
100+
)
101+
102+
default_payload = model.retrieve_default_payload()
103+
104+
# uses ml.g5.8xlarge instance
105+
predictor = model.deploy(
106+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
107+
)
108+
109+
response = predictor.predict(default_payload)
110+
111+
assert response is not None
112+
113+
89114
@pytest.mark.skipif(
90115
tests.integ.test_region() not in GATED_INFERENCE_MODEL_SUPPORTED_REGIONS,
91116
reason=f"JumpStart gated inference models unavailable in {tests.integ.test_region()}.",

0 commit comments

Comments
 (0)