39
39
TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
40
40
41
41
42
- def test_mnist_with_checkpoint_config (sagemaker_session , instance_type , tf_full_version ):
42
+ def test_mnist_with_checkpoint_config (
43
+ sagemaker_session , instance_type , tf_full_version , py_version
44
+ ):
43
45
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}" .format (
44
46
sagemaker_session .default_bucket (), sagemaker_timestamp ()
45
47
)
@@ -52,9 +54,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_
52
54
sagemaker_session = sagemaker_session ,
53
55
script_mode = True ,
54
56
framework_version = tf_full_version ,
55
- py_version = "py37"
56
- if tf_full_version == TensorFlow ._LATEST_1X_VERSION
57
- else tests .integ .PYTHON_VERSION ,
57
+ py_version = py_version ,
58
58
metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
59
59
checkpoint_s3_uri = checkpoint_s3_uri ,
60
60
checkpoint_local_path = checkpoint_local_path ,
@@ -84,7 +84,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_
84
84
assert actual_training_checkpoint_config == expected_training_checkpoint_config
85
85
86
86
87
- def test_server_side_encryption (sagemaker_session , tf_full_version ):
87
+ def test_server_side_encryption (sagemaker_session , tf_full_version , py_version ):
88
88
with kms_utils .bucket_with_encryption (sagemaker_session , ROLE ) as (bucket_with_kms , kms_key ):
89
89
output_path = os .path .join (
90
90
bucket_with_kms , "test-server-side-encryption" , time .strftime ("%y%m%d-%H%M" )
@@ -99,9 +99,7 @@ def test_server_side_encryption(sagemaker_session, tf_full_version):
99
99
sagemaker_session = sagemaker_session ,
100
100
script_mode = True ,
101
101
framework_version = tf_full_version ,
102
- py_version = "py37"
103
- if tf_full_version == TensorFlow ._LATEST_1X_VERSION
104
- else tests .integ .PYTHON_VERSION ,
102
+ py_version = py_version ,
105
103
code_location = output_path ,
106
104
output_path = output_path ,
107
105
model_dir = "/opt/ml/model" ,
@@ -128,16 +126,14 @@ def test_server_side_encryption(sagemaker_session, tf_full_version):
128
126
129
127
130
128
@pytest .mark .canary_quick
131
- def test_mnist_distributed (sagemaker_session , instance_type , tf_full_version ):
129
+ def test_mnist_distributed (sagemaker_session , instance_type , tf_full_version , py_version ):
132
130
estimator = TensorFlow (
133
131
entry_point = SCRIPT ,
134
132
role = ROLE ,
135
133
train_instance_count = 2 ,
136
134
train_instance_type = instance_type ,
137
135
sagemaker_session = sagemaker_session ,
138
- py_version = "py37"
139
- if tf_full_version == TensorFlow ._LATEST_1X_VERSION
140
- else tests .integ .PYTHON_VERSION ,
136
+ py_version = py_version ,
141
137
script_mode = True ,
142
138
framework_version = tf_full_version ,
143
139
distributions = PARAMETER_SERVER_DISTRIBUTION ,
@@ -155,15 +151,13 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version):
155
151
)
156
152
157
153
158
- def test_mnist_async (sagemaker_session , cpu_instance_type , tf_full_version ):
154
+ def test_mnist_async (sagemaker_session , cpu_instance_type , tf_full_version , py_version ):
159
155
estimator = TensorFlow (
160
156
entry_point = SCRIPT ,
161
157
role = ROLE ,
162
158
train_instance_count = 1 ,
163
159
train_instance_type = "ml.c5.4xlarge" ,
164
- py_version = "py37"
165
- if tf_full_version == TensorFlow ._LATEST_1X_VERSION
166
- else tests .integ .PYTHON_VERSION ,
160
+ py_version = py_version ,
167
161
sagemaker_session = sagemaker_session ,
168
162
script_mode = True ,
169
163
# testing py-sdk functionality, no need to run against all TF versions
@@ -199,16 +193,14 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version):
199
193
_assert_model_name_match (sagemaker_session .sagemaker_client , endpoint_name , model_name )
200
194
201
195
202
- def test_deploy_with_input_handlers (sagemaker_session , instance_type , tf_full_version ):
196
+ def test_deploy_with_input_handlers (sagemaker_session , instance_type , tf_full_version , py_version ):
203
197
estimator = TensorFlow (
204
198
entry_point = "training.py" ,
205
199
source_dir = TFS_RESOURCE_PATH ,
206
200
role = ROLE ,
207
201
train_instance_count = 1 ,
208
202
train_instance_type = instance_type ,
209
- py_version = "py37"
210
- if tf_full_version == TensorFlow ._LATEST_1X_VERSION
211
- else tests .integ .PYTHON_VERSION ,
203
+ py_version = py_version ,
212
204
sagemaker_session = sagemaker_session ,
213
205
script_mode = True ,
214
206
framework_version = tf_full_version ,
0 commit comments