Skip to content

Commit 44f3850

Browse files
committed
infra: use generated TensorFlow version fixtures
1 parent 910eebd commit 44f3850

15 files changed

+221
-265
lines changed

tests/conftest.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,18 @@ def xgboost_version(request):
164164
return request.param
165165

166166

167-
@pytest.fixture(scope="module")
168-
def tf_version(tensorflow_training_version):
169-
# TODO: remove this fixture and update tests
170-
if tensorflow_training_version in ("1.13.1", "2.2", "2.2.0"):
171-
pytest.skip("version isn't compatible with both training and inference.")
172-
return tensorflow_training_version
167+
@pytest.fixture(scope="module", params=["py2", "py3"])
168+
def tensorflow_training_py_version(tensorflow_training_version, request):
169+
return _tf_py_version(tensorflow_training_version, request)
173170

174171

175172
@pytest.fixture(scope="module", params=["py2", "py3"])
176-
def tf_py_version(tensorflow_training_version, request):
177-
version = Version(tensorflow_training_version)
173+
def tensorflow_inference_py_version(tensorflow_inference_version, request):
174+
return _tf_py_version(tensorflow_inference_version, request)
175+
176+
177+
def _tf_py_version(tf_version, request):
178+
version = Version(tf_version)
178179
if version < Version("1.11"):
179180
return "py2"
180181
if version < Version("2.2"):
@@ -253,28 +254,18 @@ def sklearn_full_py_version():
253254

254255

255256
@pytest.fixture(scope="module")
256-
def tf_training_latest_version():
257-
return "2.2.0"
258-
259-
260-
@pytest.fixture(scope="module")
261-
def tf_training_latest_py_version():
262-
return "py37"
263-
264-
265-
@pytest.fixture(scope="module")
266-
def tf_serving_latest_version():
267-
return "2.1.0"
268-
269-
270-
@pytest.fixture(scope="module")
271-
def tf_full_version(tf_training_latest_version, tf_serving_latest_version):
257+
def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_latest_version):
272258
"""Fixture for TF tests that test both training and inference.
273259
274260
Fixture exists as such, since TF training and TFS have different latest versions.
275261
Otherwise, this would simply be a single latest version.
276262
"""
277-
return str(min(Version(tf_training_latest_version), Version(tf_serving_latest_version)))
263+
return str(
264+
min(
265+
Version(tensorflow_training_latest_version),
266+
Version(tensorflow_inference_latest_version),
267+
)
268+
)
278269

279270

280271
@pytest.fixture(scope="module")
@@ -292,11 +283,6 @@ def tf_full_py_version(tf_full_version):
292283
return "py37"
293284

294285

295-
@pytest.fixture(scope="module")
296-
def ei_tf_full_version():
297-
return "2.0.0"
298-
299-
300286
@pytest.fixture(scope="module")
301287
def xgboost_full_version():
302288
return "1.0-1"

tests/integ/test_airflow_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,10 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
512512

513513
@pytest.mark.canary_quick
514514
def test_tf_airflow_config_uploads_data_source_to_s3(
515-
sagemaker_session, cpu_instance_type, tf_training_latest_version, tf_training_latest_py_version
515+
sagemaker_session,
516+
cpu_instance_type,
517+
tensorflow_training_latest_version,
518+
tensorflow_training_latest_py_version,
516519
):
517520
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
518521
tf = TensorFlow(
@@ -524,8 +527,8 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
524527
instance_count=SINGLE_INSTANCE_COUNT,
525528
instance_type=cpu_instance_type,
526529
sagemaker_session=sagemaker_session,
527-
framework_version=tf_training_latest_version,
528-
py_version=tf_training_latest_py_version,
530+
framework_version=tensorflow_training_latest_version,
531+
py_version=tensorflow_training_latest_py_version,
529532
metric_definitions=[
530533
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
531534
],

tests/integ/test_data_capture_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
44-
sagemaker_session, tf_serving_latest_version
44+
sagemaker_session, tensorflow_inference_latest_version
4545
):
4646
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
4747
model_data = sagemaker_session.upload_data(
@@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5252
model = TensorFlowModel(
5353
model_data=model_data,
5454
role=ROLE,
55-
framework_version=tf_serving_latest_version,
55+
framework_version=tensorflow_inference_latest_version,
5656
sagemaker_session=sagemaker_session,
5757
)
5858
predictor = model.deploy(
@@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
9898

9999

100100
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
101-
sagemaker_session, tf_serving_latest_version
101+
sagemaker_session, tensorflow_inference_latest_version
102102
):
103103
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
104104
model_data = sagemaker_session.upload_data(
@@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
109109
model = TensorFlowModel(
110110
model_data=model_data,
111111
role=ROLE,
112-
framework_version=tf_serving_latest_version,
112+
framework_version=tensorflow_inference_latest_version,
113113
sagemaker_session=sagemaker_session,
114114
)
115115
destination_s3_uri = os.path.join(
@@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
184184

185185

186186
def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
187-
sagemaker_session, tf_serving_latest_version
187+
sagemaker_session, tensorflow_inference_latest_version
188188
):
189189
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
190190
model_data = sagemaker_session.upload_data(
@@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
195195
model = TensorFlowModel(
196196
model_data=model_data,
197197
role=ROLE,
198-
framework_version=tf_serving_latest_version,
198+
framework_version=tensorflow_inference_latest_version,
199199
sagemaker_session=sagemaker_session,
200200
)
201201
destination_s3_uri = os.path.join(

tests/integ/test_horovod.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@
3131
@pytest.mark.canary_quick
3232
def test_hvd_cpu(
3333
sagemaker_session,
34-
tf_training_latest_version,
35-
tf_training_latest_py_version,
34+
tensorflow_training_latest_version,
35+
tensorflow_training_latest_py_version,
3636
cpu_instance_type,
3737
tmpdir,
3838
):
3939
_create_and_fit_estimator(
4040
sagemaker_session,
41-
tf_training_latest_version,
42-
tf_training_latest_py_version,
41+
tensorflow_training_latest_version,
42+
tensorflow_training_latest_py_version,
4343
cpu_instance_type,
4444
tmpdir,
4545
)
@@ -50,12 +50,15 @@ def test_hvd_cpu(
5050
integ.test_region() in integ.TRAINING_NO_P2_REGIONS, reason="no ml.p2 instances in this region"
5151
)
5252
def test_hvd_gpu(
53-
sagemaker_session, tf_training_latest_version, tf_training_latest_py_version, tmpdir
53+
sagemaker_session,
54+
tensorflow_training_latest_version,
55+
tensorflow_training_latest_py_version,
56+
tmpdir,
5457
):
5558
_create_and_fit_estimator(
5659
sagemaker_session,
57-
tf_training_latest_version,
58-
tf_training_latest_py_version,
60+
tensorflow_training_latest_version,
61+
tensorflow_training_latest_py_version,
5962
"ml.p2.xlarge",
6063
tmpdir,
6164
)
@@ -65,8 +68,8 @@ def test_hvd_gpu(
6568
@pytest.mark.parametrize("instances, processes", [[1, 2], (2, 1), (2, 2)])
6669
def test_horovod_local_mode(
6770
sagemaker_local_session,
68-
tf_training_latest_version,
69-
tf_training_latest_py_version,
71+
tensorflow_training_latest_version,
72+
tensorflow_training_latest_py_version,
7073
instances,
7174
processes,
7275
tmpdir,
@@ -80,8 +83,8 @@ def test_horovod_local_mode(
8083
instance_type="local",
8184
sagemaker_session=sagemaker_local_session,
8285
output_path=output_path,
83-
framework_version=tf_training_latest_version,
84-
py_version=tf_training_latest_py_version,
86+
framework_version=tensorflow_training_latest_version,
87+
py_version=tensorflow_training_latest_py_version,
8588
distribution={"mpi": {"enabled": True, "processes_per_host": processes}},
8689
)
8790

tests/integ/test_model_monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888

8989

9090
@pytest.fixture(scope="module")
91-
def predictor(sagemaker_session, tf_serving_latest_version):
91+
def predictor(sagemaker_session, tensorflow_inference_latest_version):
9292
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
9393
model_data = sagemaker_session.upload_data(
9494
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_serving_latest_version):
100100
model = TensorFlowModel(
101101
model_data=model_data,
102102
role=ROLE,
103-
framework_version=tf_serving_latest_version,
103+
framework_version=tensorflow_inference_latest_version,
104104
sagemaker_session=sagemaker_session,
105105
)
106106
predictor = model.deploy(

tests/integ/test_tf.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@
3939

4040

4141
def test_mnist_with_checkpoint_config(
42-
sagemaker_session, instance_type, tf_training_latest_version, tf_training_latest_py_version
42+
sagemaker_session,
43+
instance_type,
44+
tensorflow_training_latest_version,
45+
tensorflow_training_latest_py_version,
4346
):
4447
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
4548
sagemaker_session.default_bucket(), sagemaker_timestamp()
@@ -51,8 +54,8 @@ def test_mnist_with_checkpoint_config(
5154
instance_count=1,
5255
instance_type=instance_type,
5356
sagemaker_session=sagemaker_session,
54-
framework_version=tf_training_latest_version,
55-
py_version=tf_training_latest_py_version,
57+
framework_version=tensorflow_training_latest_version,
58+
py_version=tensorflow_training_latest_py_version,
5659
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
5760
checkpoint_s3_uri=checkpoint_s3_uri,
5861
checkpoint_local_path=checkpoint_local_path,
@@ -124,16 +127,19 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
124127

125128
@pytest.mark.canary_quick
126129
def test_mnist_distributed(
127-
sagemaker_session, instance_type, tf_training_latest_version, tf_training_latest_py_version
130+
sagemaker_session,
131+
instance_type,
132+
tensorflow_training_latest_version,
133+
tensorflow_training_latest_py_version,
128134
):
129135
estimator = TensorFlow(
130136
entry_point=SCRIPT,
131137
role=ROLE,
132138
instance_count=2,
133139
instance_type=instance_type,
134140
sagemaker_session=sagemaker_session,
135-
framework_version=tf_training_latest_version,
136-
py_version=tf_training_latest_py_version,
141+
framework_version=tensorflow_training_latest_version,
142+
py_version=tensorflow_training_latest_py_version,
137143
distribution=PARAMETER_SERVER_DISTRIBUTION,
138144
)
139145
inputs = estimator.sagemaker_session.upload_data(

tests/integ/test_tf_efs_fsx.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def test_mnist_efs(
5858
efs_fsx_setup,
5959
sagemaker_session,
6060
cpu_instance_type,
61-
tf_training_latest_version,
62-
tf_training_latest_py_version,
61+
tensorflow_training_latest_version,
62+
tensorflow_training_latest_py_version,
6363
):
6464
role = efs_fsx_setup["role_name"]
6565
subnets = [efs_fsx_setup["subnet_id"]]
@@ -71,8 +71,8 @@ def test_mnist_efs(
7171
instance_count=1,
7272
instance_type=cpu_instance_type,
7373
sagemaker_session=sagemaker_session,
74-
framework_version=tf_training_latest_version,
75-
py_version=tf_training_latest_py_version,
74+
framework_version=tensorflow_training_latest_version,
75+
py_version=tensorflow_training_latest_py_version,
7676
subnets=subnets,
7777
security_group_ids=security_group_ids,
7878
)
@@ -103,8 +103,8 @@ def test_mnist_lustre(
103103
efs_fsx_setup,
104104
sagemaker_session,
105105
cpu_instance_type,
106-
tf_training_latest_version,
107-
tf_training_latest_py_version,
106+
tensorflow_training_latest_version,
107+
tensorflow_training_latest_py_version,
108108
):
109109
role = efs_fsx_setup["role_name"]
110110
subnets = [efs_fsx_setup["subnet_id"]]
@@ -116,8 +116,8 @@ def test_mnist_lustre(
116116
instance_count=1,
117117
instance_type=cpu_instance_type,
118118
sagemaker_session=sagemaker_session,
119-
framework_version=tf_training_latest_version,
120-
py_version=tf_training_latest_py_version,
119+
framework_version=tensorflow_training_latest_version,
120+
py_version=tensorflow_training_latest_py_version,
121121
subnets=subnets,
122122
security_group_ids=security_group_ids,
123123
)
@@ -144,8 +144,8 @@ def test_tuning_tf_efs(
144144
efs_fsx_setup,
145145
sagemaker_session,
146146
cpu_instance_type,
147-
tf_training_latest_version,
148-
tf_training_latest_py_version,
147+
tensorflow_training_latest_version,
148+
tensorflow_training_latest_py_version,
149149
):
150150
role = efs_fsx_setup["role_name"]
151151
subnets = [efs_fsx_setup["subnet_id"]]
@@ -157,8 +157,8 @@ def test_tuning_tf_efs(
157157
instance_count=1,
158158
instance_type=cpu_instance_type,
159159
sagemaker_session=sagemaker_session,
160-
framework_version=tf_training_latest_version,
161-
py_version=tf_training_latest_py_version,
160+
framework_version=tensorflow_training_latest_version,
161+
py_version=tensorflow_training_latest_py_version,
162162
subnets=subnets,
163163
security_group_ids=security_group_ids,
164164
)
@@ -197,8 +197,8 @@ def test_tuning_tf_lustre(
197197
efs_fsx_setup,
198198
sagemaker_session,
199199
cpu_instance_type,
200-
tf_training_latest_version,
201-
tf_training_latest_py_version,
200+
tensorflow_training_latest_version,
201+
tensorflow_training_latest_py_version,
202202
):
203203
role = efs_fsx_setup["role_name"]
204204
subnets = [efs_fsx_setup["subnet_id"]]
@@ -210,8 +210,8 @@ def test_tuning_tf_lustre(
210210
instance_count=1,
211211
instance_type=cpu_instance_type,
212212
sagemaker_session=sagemaker_session,
213-
framework_version=tf_training_latest_version,
214-
py_version=tf_training_latest_py_version,
213+
framework_version=tensorflow_training_latest_version,
214+
py_version=tensorflow_training_latest_py_version,
215215
subnets=subnets,
216216
security_group_ids=security_group_ids,
217217
)

0 commit comments

Comments
 (0)