21
21
from sagemaker_tensorflow_container .training import SAGEMAKER_PARAMETER_SERVER_ENABLED
22
22
23
23
24
- def test_mnist (sagemaker_session , ecr_image , instance_type ):
24
+ def test_mnist (sagemaker_session , ecr_image , instance_type , framework_version ):
25
25
resource_path = os .path .join (os .path .dirname (__file__ ), '../..' , 'resources' )
26
26
script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
27
27
estimator = TensorFlow (entry_point = script ,
@@ -30,7 +30,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type):
30
30
train_instance_count = 1 ,
31
31
sagemaker_session = sagemaker_session ,
32
32
image_name = ecr_image ,
33
- framework_version = '1.12.0' ,
33
+ framework_version = framework_version ,
34
34
py_version = 'py3' ,
35
35
base_job_name = 'test-sagemaker-mnist' )
36
36
inputs = estimator .sagemaker_session .upload_data (
@@ -40,7 +40,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type):
40
40
_assert_s3_file_exists (estimator .model_data )
41
41
42
42
43
- def test_distributed_mnist_no_ps (sagemaker_session , ecr_image , instance_type ):
43
+ def test_distributed_mnist_no_ps (sagemaker_session , ecr_image , instance_type , framework_version ):
44
44
resource_path = os .path .join (os .path .dirname (__file__ ), '../..' , 'resources' )
45
45
script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
46
46
estimator = TensorFlow (entry_point = script ,
@@ -49,7 +49,7 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type):
49
49
train_instance_type = instance_type ,
50
50
sagemaker_session = sagemaker_session ,
51
51
image_name = ecr_image ,
52
- framework_version = '1.12.0' ,
52
+ framework_version = framework_version ,
53
53
py_version = 'py3' ,
54
54
base_job_name = 'test-tf-sm-distributed-mnist' )
55
55
inputs = estimator .sagemaker_session .upload_data (
@@ -59,7 +59,7 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type):
59
59
_assert_s3_file_exists (estimator .model_data )
60
60
61
61
62
- def test_distributed_mnist_ps (sagemaker_session , ecr_image , instance_type ):
62
+ def test_distributed_mnist_ps (sagemaker_session , ecr_image , instance_type , framework_version ):
63
63
resource_path = os .path .join (os .path .dirname (__file__ ), '..' , '..' , 'resources' )
64
64
script = os .path .join (resource_path , 'mnist' , 'mnist_estimator.py' )
65
65
estimator = TensorFlow (entry_point = script ,
@@ -69,7 +69,7 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type):
69
69
train_instance_type = instance_type ,
70
70
sagemaker_session = sagemaker_session ,
71
71
image_name = ecr_image ,
72
- framework_version = '1.12.0' ,
72
+ framework_version = framework_version ,
73
73
py_version = 'py3' ,
74
74
base_job_name = 'test-tf-sm-distributed-mnist' )
75
75
inputs = estimator .sagemaker_session .upload_data (
@@ -80,7 +80,7 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type):
80
80
_assert_s3_file_exists (estimator .model_data )
81
81
82
82
83
- def test_s3_plugin (sagemaker_session , ecr_image , instance_type , region ):
83
+ def test_s3_plugin (sagemaker_session , ecr_image , instance_type , region , framework_version ):
84
84
resource_path = os .path .join (os .path .dirname (__file__ ), '..' , '..' , 'resources' )
85
85
script = os .path .join (resource_path , 'mnist' , 'mnist_estimator.py' )
86
86
estimator = TensorFlow (entry_point = script ,
@@ -91,7 +91,7 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region):
91
91
# Disable throttling for checkpoint and model saving
92
92
'throttle-secs' : 0 ,
93
93
# Without the patch training jobs would fail around 100th to
94
- # 150th steps
94
+ # 150th step
95
95
'max-steps' : 200 ,
96
96
# Large batch size would result in a larger checkpoint file
97
97
'batch-size' : 2048 ,
@@ -103,7 +103,7 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region):
103
103
train_instance_type = instance_type ,
104
104
sagemaker_session = sagemaker_session ,
105
105
image_name = ecr_image ,
106
- framework_version = '1.12.0' ,
106
+ framework_version = framework_version ,
107
107
py_version = 'py3' ,
108
108
base_job_name = 'test-tf-sm-s3-mnist' )
109
109
estimator .fit ('s3://sagemaker-sample-data-{}/tensorflow/mnist' .format (region ))
0 commit comments