Skip to content

Commit f0e6b35

Browse files
author
Chuyang Deng
committed
add py_version to fixture
1 parent b998482 commit f0e6b35

File tree

2 files changed

+19
-20
lines changed

2 files changed

+19
-20
lines changed

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def custom_bucket_name(boto_session):
133133
return "{}-{}-{}".format(CUSTOM_BUCKET_NAME_PREFIX, region, account)
134134

135135

136+
@pytest.fixture(scope="module")
137+
def py_version():
138+
return (
139+
"py37" if tf_full_version == TensorFlow._LATEST_1X_VERSION else tests.integ.PYTHON_VERSION
140+
)
141+
142+
136143
@pytest.fixture(scope="module", params=["4.0", "4.0.0", "4.1", "4.1.0", "5.0", "5.0.0"])
137144
def chainer_version(request):
138145
return request.param

tests/integ/test_tf_script_mode.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939
TAGS = [{"Key": "some-key", "Value": "some-value"}]
4040

4141

42-
def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_version):
42+
def test_mnist_with_checkpoint_config(
43+
sagemaker_session, instance_type, tf_full_version, py_version
44+
):
4345
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
4446
sagemaker_session.default_bucket(), sagemaker_timestamp()
4547
)
@@ -52,9 +54,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_
5254
sagemaker_session=sagemaker_session,
5355
script_mode=True,
5456
framework_version=tf_full_version,
55-
py_version="py37"
56-
if tf_full_version == TensorFlow._LATEST_1X_VERSION
57-
else tests.integ.PYTHON_VERSION,
57+
py_version=py_version,
5858
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
5959
checkpoint_s3_uri=checkpoint_s3_uri,
6060
checkpoint_local_path=checkpoint_local_path,
@@ -84,7 +84,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_
8484
assert actual_training_checkpoint_config == expected_training_checkpoint_config
8585

8686

87-
def test_server_side_encryption(sagemaker_session, tf_full_version):
87+
def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
8888
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
8989
output_path = os.path.join(
9090
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
@@ -99,9 +99,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version):
9999
sagemaker_session=sagemaker_session,
100100
script_mode=True,
101101
framework_version=tf_full_version,
102-
py_version="py37"
103-
if tf_full_version == TensorFlow._LATEST_1X_VERSION
104-
else tests.integ.PYTHON_VERSION,
102+
py_version=py_version,
105103
code_location=output_path,
106104
output_path=output_path,
107105
model_dir="/opt/ml/model",
@@ -128,16 +126,14 @@ def test_server_side_encryption(sagemaker_session, tf_full_version):
128126

129127

130128
@pytest.mark.canary_quick
131-
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version):
129+
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py_version):
132130
estimator = TensorFlow(
133131
entry_point=SCRIPT,
134132
role=ROLE,
135133
train_instance_count=2,
136134
train_instance_type=instance_type,
137135
sagemaker_session=sagemaker_session,
138-
py_version="py37"
139-
if tf_full_version == TensorFlow._LATEST_1X_VERSION
140-
else tests.integ.PYTHON_VERSION,
136+
py_version=py_version,
141137
script_mode=True,
142138
framework_version=tf_full_version,
143139
distributions=PARAMETER_SERVER_DISTRIBUTION,
@@ -155,15 +151,13 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version):
155151
)
156152

157153

158-
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version):
154+
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
159155
estimator = TensorFlow(
160156
entry_point=SCRIPT,
161157
role=ROLE,
162158
train_instance_count=1,
163159
train_instance_type="ml.c5.4xlarge",
164-
py_version="py37"
165-
if tf_full_version == TensorFlow._LATEST_1X_VERSION
166-
else tests.integ.PYTHON_VERSION,
160+
py_version=py_version,
167161
sagemaker_session=sagemaker_session,
168162
script_mode=True,
169163
# testing py-sdk functionality, no need to run against all TF versions
@@ -199,16 +193,14 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version):
199193
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
200194

201195

202-
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_version):
196+
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_version, py_version):
203197
estimator = TensorFlow(
204198
entry_point="training.py",
205199
source_dir=TFS_RESOURCE_PATH,
206200
role=ROLE,
207201
train_instance_count=1,
208202
train_instance_type=instance_type,
209-
py_version="py37"
210-
if tf_full_version == TensorFlow._LATEST_1X_VERSION
211-
else tests.integ.PYTHON_VERSION,
203+
py_version=py_version,
212204
sagemaker_session=sagemaker_session,
213205
script_mode=True,
214206
framework_version=tf_full_version,

0 commit comments

Comments
 (0)