File tree Expand file tree Collapse file tree 2 files changed +27
-2
lines changed
tests/integ/sagemaker/jumpstart/model Expand file tree Collapse file tree 2 files changed +27
-2
lines changed Original file line number Diff line number Diff line change @@ -45,8 +45,8 @@ def retrieve_samples(
45
45
model_version (str): The version of the JumpStart model for which to retrieve
46
46
the model payloads.
47
47
serialize (bool): Whether to serialize byte-stream valued payloads by downloading
48
- binary files from s3 and applying encoding, or to keep payload in pre-serialized .
49
- state. Set this option to False if you only want to avoid s3 download of it you
48
+ binary files from s3 and applying encoding, or to keep payload in pre-serialized
49
+ state. Set this option to False if you want to avoid s3 downloads or if you
50
50
want to inspect the payload in a human-readable form. (Default: True).
51
51
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model
52
52
specifications should be tolerated without raising an exception. If ``False``, raises an
Original file line number Diff line number Diff line change @@ -86,6 +86,31 @@ def test_prepacked_jumpstart_model(setup):
86
86
assert response is not None
87
87
88
88
89
+ def test_default_payload_jumpstart_model (setup ):
90
+
91
+ ## DO NOT COMMIT THIS LINE
92
+ os .environ .update ({"AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE" : "jumpstart-cache-alpha-us-west-2" })
93
+
94
+ model_id = "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16"
95
+
96
+ model = JumpStartModel (
97
+ model_id = model_id ,
98
+ role = get_sm_session ().get_caller_identity_arn (),
99
+ sagemaker_session = get_sm_session (),
100
+ )
101
+
102
+ default_payload = model .retrieve_default_payload ()
103
+
104
+ # uses ml.g5.8xlarge instance
105
+ predictor = model .deploy (
106
+ tags = [{"Key" : JUMPSTART_TAG , "Value" : os .environ [ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID ]}],
107
+ )
108
+
109
+ response = predictor .predict (default_payload )
110
+
111
+ assert response is not None
112
+
113
+
89
114
@pytest .mark .skipif (
90
115
tests .integ .test_region () not in GATED_INFERENCE_MODEL_SUPPORTED_REGIONS ,
91
116
reason = f"JumpStart gated inference models unavailable in { tests .integ .test_region ()} ." ,
You can’t perform that action at this time.
0 commit comments