Skip to content

Commit bad3792

Browse files
committed
Separate out integ test
1 parent b6c5b6f commit bad3792

File tree

1 file changed

+58
-2
lines changed

1 file changed

+58
-2
lines changed

tests/integ/sagemaker/serve/test_schema_builder.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode(
152152
"question-answering",
153153
"ml.m5.xlarge",
154154
),
155-
("deepset/roberta-base-squad2", "question-answering", "ml.m5.xlarge"),
156-
("openai/whisper-large-v3", "automatic-speech-recognition", "ml.m5.xlarge")
155+
("deepset/roberta-base-squad2", "question-answering", "ml.m5.xlarge")
157156
],
158157
)
159158
def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
@@ -203,6 +202,63 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
203202
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
204203

205204

205+
@pytest.mark.skipif(
206+
PYTHON_VERSION_IS_NOT_310,
207+
reason="Testing Schema Builder Simplification feature - Remote Schema",
208+
)
209+
@pytest.mark.parametrize(
210+
"model_id, task_provided, instance_type_provided",
211+
[
212+
("openai/whisper-large-v3", "automatic-speech-recognition", "ml.m5.xlarge")
213+
],
214+
)
215+
def test_model_builder_happy_path_with_task_provided_remote_schema_mode_asr(
216+
model_id, task_provided, sagemaker_session, instance_type_provided
217+
):
218+
model_builder = ModelBuilder(
219+
model=model_id,
220+
model_metadata={"HF_TASK": task_provided},
221+
instance_type=instance_type_provided,
222+
)
223+
model = model_builder.build(sagemaker_session=sagemaker_session)
224+
225+
assert model is not None
226+
assert model_builder.schema_builder is not None
227+
228+
remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
229+
inputs, outputs = remote_hf_schema_helper.get_resolved_hf_schema_for_task(task_provided)
230+
assert model_builder.schema_builder.sample_input == inputs
231+
assert model_builder.schema_builder.sample_output == outputs
232+
233+
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
234+
caught_ex = None
235+
try:
236+
iam_client = sagemaker_session.boto_session.client("iam")
237+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
238+
239+
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
240+
predictor = model.deploy(
241+
role=role_arn, instance_count=1, instance_type=instance_type_provided
242+
)
243+
244+
predicted_outputs = predictor.predict(inputs)
245+
assert predicted_outputs is not None
246+
247+
except Exception as e:
248+
caught_ex = e
249+
finally:
250+
cleanup_model_resources(
251+
sagemaker_session=model_builder.sagemaker_session,
252+
model_name=model.name,
253+
endpoint_name=model.endpoint_name,
254+
)
255+
if caught_ex:
256+
logger.exception(caught_ex)
257+
assert (
258+
False
259+
), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
260+
261+
206262
def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
207263
model_builder = ModelBuilder(
208264
model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"}

0 commit comments

Comments
 (0)