@@ -152,8 +152,7 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode(
152
152
"question-answering" ,
153
153
"ml.m5.xlarge" ,
154
154
),
155
- ("deepset/roberta-base-squad2" , "question-answering" , "ml.m5.xlarge" ),
156
- ("openai/whisper-large-v3" , "automatic-speech-recognition" , "ml.m5.xlarge" )
155
+ ("deepset/roberta-base-squad2" , "question-answering" , "ml.m5.xlarge" )
157
156
],
158
157
)
159
158
def test_model_builder_happy_path_with_task_provided_remote_schema_mode (
@@ -203,6 +202,63 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
203
202
), f"{ caught_ex } was thrown when running transformers sagemaker endpoint test"
204
203
205
204
205
@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing Schema Builder Simplification feature - Remote Schema",
)
@pytest.mark.parametrize(
    "model_id, task_provided, instance_type_provided",
    [("openai/whisper-large-v3", "automatic-speech-recognition", "ml.m5.xlarge")],
)
def test_model_builder_happy_path_with_task_provided_remote_schema_mode_asr(
    model_id, task_provided, sagemaker_session, instance_type_provided
):
    """Build, deploy and invoke a model whose schema is resolved from its HF task.

    The test builds a ``ModelBuilder`` from ``model_id`` with ``HF_TASK`` set to
    ``task_provided``, checks that the resulting schema builder's sample
    input/output equal what ``RemoteSchemaRetriever`` resolves for that task,
    then deploys the model and sends the sample input to the endpoint.

    Args:
        model_id: Model identifier passed to ``ModelBuilder(model=...)``.
        task_provided: Task name stored under ``model_metadata["HF_TASK"]``.
        sagemaker_session: Session used for the build and for the IAM lookup
            (presumably a pytest fixture defined elsewhere -- confirm).
        instance_type_provided: Instance type used for both build and deploy.

    Raises:
        AssertionError: If any check fails, or if deploying/predicting raised;
            in the latter case the original exception is chained as the cause.
    """
    model_builder = ModelBuilder(
        model=model_id,
        model_metadata={"HF_TASK": task_provided},
        instance_type=instance_type_provided,
    )
    model = model_builder.build(sagemaker_session=sagemaker_session)

    assert model is not None
    assert model_builder.schema_builder is not None

    # The schema builder must have picked up the same sample payloads that the
    # remote retriever resolves for this task.
    remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
    inputs, outputs = remote_hf_schema_helper.get_resolved_hf_schema_for_task(task_provided)
    assert model_builder.schema_builder.sample_input == inputs
    assert model_builder.schema_builder.sample_output == outputs

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        caught_ex = None
        try:
            iam_client = sagemaker_session.boto_session.client("iam")
            role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(
                role=role_arn, instance_count=1, instance_type=instance_type_provided
            )

            predicted_outputs = predictor.predict(inputs)
            assert predicted_outputs is not None

        except Exception as e:
            # Log while the exception is still being handled so the traceback
            # is attached; outside a handler logger.exception has nothing to
            # report beyond the message.
            logger.exception(e)
            caught_ex = e
        finally:
            # Always tear down the model/endpoint, whether or not deploy or
            # predict succeeded.
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )

        if caught_ex is not None:
            # Fail only after cleanup has run, chaining the original error so
            # its traceback appears in the test report.
            raise AssertionError(
                f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
            ) from caught_ex
260
+
261
+
206
262
def test_model_builder_negative_path_with_invalid_task (sagemaker_session ):
207
263
model_builder = ModelBuilder (
208
264
model = "bert-base-uncased" , model_metadata = {"HF_TASK" : "invalid-task" }
0 commit comments