@@ -212,18 +212,21 @@ def test_mnist_async(
212
212
sagemaker_session ,
213
213
cpu_instance_type ,
214
214
tf_full_version ,
215
- tensorflow_training_latest_version ,
216
215
tf_full_py_version
217
216
):
218
217
219
218
# Use the latest patch version for training, if available
220
- tf_full_v = Version (tf_full_version )
221
- tf_training_latest_v = Version (tensorflow_training_latest_version )
219
+ # tf_full_v = Version(tf_full_version)
220
+ # tf_training_latest_v = Version(tensorflow_training_latest_version)
221
+ #
222
+ # if (tf_full_v.major, tf_full_v.minor) == (tf_training_latest_v.major, tf_training_latest_v.minor):
223
+ # tf_fw_version = tensorflow_training_latest_version
224
+ # else:
225
+ # tf_fw_version = tf_full_version
222
226
223
- if (tf_full_v .major , tf_full_v .minor ) == (tf_training_latest_v .major , tf_training_latest_v .minor ):
224
- tf_fw_version = tensorflow_training_latest_version
225
- else :
226
- tf_fw_version = tf_full_version
227
+ # test
228
+ if tf_full_version == "2.7.0" :
229
+ tf_full_version = "2.7.1"
227
230
228
231
estimator = TensorFlow (
229
232
entry_point = SCRIPT ,
@@ -232,7 +235,7 @@ def test_mnist_async(
232
235
instance_count = 1 ,
233
236
instance_type = "ml.c5.4xlarge" ,
234
237
sagemaker_session = sagemaker_session ,
235
- framework_version = tf_fw_version ,
238
+ framework_version = tf_full_version ,
236
239
py_version = tf_full_py_version ,
237
240
tags = TAGS ,
238
241
)
0 commit comments