@@ -50,8 +50,7 @@ def imagenet_val_set(request, sagemaker_session, tmpdir_factory):
50
50
key_prefix = "Imagenet/TFRecords/validation" ,
51
51
)
52
52
train_input = sagemaker_session .upload_data (
53
- path = local_path ,
54
- key_prefix = "integ-test-data/trcomp/tensorflow/imagenet/val" ,
53
+ path = local_path , key_prefix = "integ-test-data/trcomp/tensorflow/imagenet/val"
55
54
)
56
55
return train_input
57
56
@@ -149,11 +148,11 @@ def test_pytorch(
149
148
Test the PyTorch estimator
150
149
"""
151
150
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
152
- data_path = os .path .join (DATA_DIR , "huggingface_byoc" )
153
151
154
152
hf = PyTorch (
155
153
py_version = "py38" ,
156
- entry_point = os .path .join (data_path , "run_glue.py" ),
154
+ source_dir = os .path .join (DATA_DIR , "huggingface_byoc" ),
155
+ entry_point = "run_glue.py" ,
157
156
role = "SageMakerRole" ,
158
157
framework_version = pytorch_training_compiler_latest_version ,
159
158
instance_count = instance_count ,
@@ -217,10 +216,7 @@ def test_huggingface_tensorflow(
217
216
218
217
@pytest .mark .release
219
218
def test_tensorflow (
220
- sagemaker_session ,
221
- gpu_instance_type ,
222
- tensorflow_training_latest_version ,
223
- imagenet_val_set ,
219
+ sagemaker_session , gpu_instance_type , tensorflow_training_latest_version , imagenet_val_set
224
220
):
225
221
"""
226
222
Test the TensorFlow estimator
@@ -272,8 +268,4 @@ def test_tensorflow(
272
268
compiler_config = TFTrainingCompilerConfig (),
273
269
)
274
270
275
- tf .fit (
276
- inputs = imagenet_val_set ,
277
- logs = True ,
278
- wait = True ,
279
- )
271
+ tf .fit (inputs = imagenet_val_set , logs = True , wait = True )
0 commit comments