@@ -105,15 +105,24 @@ def test_model_builder_negative_path(sagemaker_session):
105
105
PYTHON_VERSION_IS_NOT_310 ,
106
106
reason = "Testing Schema Builder Simplification feature" ,
107
107
)
108
- def test_model_builder_happy_path_with_task_provided (sagemaker_session , gpu_instance_type ):
109
- model_builder = ModelBuilder (model = "bert-base-uncased:fill-mask" )
108
+ @pytest .mark .parametrize (
109
+ "model_id, task_provided" ,
110
+ [
111
+ ("bert-base-uncased" , "fill-mask" ),
112
+ ("bert-large-uncased-whole-word-masking-finetuned-squad" , "question-answering" ),
113
+ ],
114
+ )
115
+ def test_model_builder_happy_path_with_task_provided (
116
+ model_id , task_provided , sagemaker_session , gpu_instance_type
117
+ ):
118
+ model_builder = ModelBuilder (model = f"{ model_id } :{ task_provided } " )
110
119
111
120
model = model_builder .build (sagemaker_session = sagemaker_session )
112
121
113
122
assert model is not None
114
123
assert model_builder .schema_builder is not None
115
124
116
- inputs , outputs = task .retrieve_local_schemas ("fill-mask" )
125
+ inputs , outputs = task .retrieve_local_schemas (task_provided )
117
126
assert model_builder .schema_builder .sample_input == inputs
118
127
assert model_builder .schema_builder .sample_output == outputs
119
128
0 commit comments