17
17
import boto3
18
18
import pytest
19
19
from sagemaker .tensorflow import TensorFlow
20
+ from sagemaker .tuner import HyperparameterTuner , IntegerParameter
20
21
from six .moves .urllib .parse import urlparse
21
22
22
23
from test .integration .utils import processor , py_version , unique_name_from_base # noqa: F401
24
+ from timeout import timeout
23
25
24
26
25
27
@pytest .mark .deploy_test
26
28
def test_mnist (sagemaker_session , ecr_image , instance_type , framework_version ):
27
- resource_path = os .path .join (os .path .dirname (__file__ ), '../ ..' , 'resources' )
29
+ resource_path = os .path .join (os .path .dirname (__file__ ), '..' , ' ..' , 'resources' )
28
30
script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
29
31
estimator = TensorFlow (entry_point = script ,
30
32
role = 'SageMakerRole' ,
@@ -42,7 +44,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
42
44
43
45
44
46
def test_distributed_mnist_no_ps (sagemaker_session , ecr_image , instance_type , framework_version ):
45
- resource_path = os .path .join (os .path .dirname (__file__ ), '../ ..' , 'resources' )
47
+ resource_path = os .path .join (os .path .dirname (__file__ ), '..' , ' ..' , 'resources' )
46
48
script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
47
49
estimator = TensorFlow (entry_point = script ,
48
50
role = 'SageMakerRole' ,
@@ -110,6 +112,40 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
110
112
_assert_checkpoint_exists (region , estimator .model_dir , 200 )
111
113
112
114
115
+ def test_tuning (sagemaker_session , ecr_image , instance_type , framework_version ):
116
+ resource_path = os .path .join (os .path .dirname (__file__ ), '..' , '..' , 'resources' )
117
+ script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
118
+
119
+ estimator = TensorFlow (entry_point = script ,
120
+ role = 'SageMakerRole' ,
121
+ train_instance_type = instance_type ,
122
+ train_instance_count = 1 ,
123
+ sagemaker_session = sagemaker_session ,
124
+ image_name = ecr_image ,
125
+ framework_version = framework_version ,
126
+ script_mode = True )
127
+
128
+ hyperparameter_ranges = {'epochs' : IntegerParameter (1 , 2 )}
129
+ objective_metric_name = 'accuracy'
130
+ metric_definitions = [{'Name' : objective_metric_name , 'Regex' : 'accuracy = ([0-9\\ .]+)' }]
131
+
132
+ tuner = HyperparameterTuner (estimator ,
133
+ objective_metric_name ,
134
+ hyperparameter_ranges ,
135
+ metric_definitions ,
136
+ max_jobs = 2 ,
137
+ max_parallel_jobs = 2 )
138
+
139
+ with timeout (minutes = 20 ):
140
+ inputs = estimator .sagemaker_session .upload_data (
141
+ path = os .path .join (resource_path , 'mnist' , 'data' ),
142
+ key_prefix = 'scriptmode/mnist' )
143
+
144
+ tuning_job_name = unique_name_from_base ('test-tf-sm-tuning' , max_length = 32 )
145
+ tuner .fit (inputs , job_name = tuning_job_name )
146
+ tuner .wait ()
147
+
148
+
113
149
def _assert_checkpoint_exists (region , model_dir , checkpoint_number ):
114
150
_assert_s3_file_exists (region , os .path .join (model_dir , 'graph.pbtxt' ))
115
151
_assert_s3_file_exists (region ,
0 commit comments