1
- # Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1
+ # Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
4
# may not use this file except in compliance with the License. A copy of
40
40
PS_TASK_1 = {'index' : 0 , 'type' : 'ps' }
41
41
PS_TASK_2 = {'index' : 1 , 'type' : 'ps' }
42
42
MODEL_DIR = 's3://bucket/prefix'
43
+ MODEL_DIR_CMD_LIST = ['--model_dir' , MODEL_DIR ]
43
44
REGION = 'us-west-2'
44
45
RESOURCE_PATH = os .path .join (os .path .dirname (__file__ ), '..' , 'resources' )
45
46
@@ -82,9 +83,8 @@ def test_is_host_master():
82
83
83
84
@patch ('sagemaker_containers.beta.framework.entry_point.run' )
84
85
def test_single_machine (run_module , single_machine_training_env ):
85
- training .train (single_machine_training_env )
86
- run_module .assert_called_with (MODULE_DIR , MODULE_NAME ,
87
- single_machine_training_env .to_cmd_args (),
86
+ training .train (single_machine_training_env , MODEL_DIR_CMD_LIST )
87
+ run_module .assert_called_with (MODULE_DIR , MODULE_NAME , MODEL_DIR_CMD_LIST ,
88
88
single_machine_training_env .to_env_vars (),
89
89
runner = runner .ProcessRunnerType )
90
90
@@ -93,9 +93,8 @@ def test_single_machine(run_module, single_machine_training_env):
93
93
def test_train_horovod (run_module , single_machine_training_env ):
94
94
single_machine_training_env .additional_framework_parameters ['sagemaker_mpi_enabled' ] = True
95
95
96
- training .train (single_machine_training_env )
97
- run_module .assert_called_with (MODULE_DIR , MODULE_NAME ,
98
- single_machine_training_env .to_cmd_args (),
96
+ training .train (single_machine_training_env , MODEL_DIR_CMD_LIST )
97
+ run_module .assert_called_with (MODULE_DIR , MODULE_NAME , MODEL_DIR_CMD_LIST ,
99
98
single_machine_training_env .to_env_vars (),
100
99
runner = runner .MPIRunnerType )
101
100
@@ -108,7 +107,7 @@ def test_train_horovod(run_module, single_machine_training_env):
108
107
@patch ('threading.Thread' , lambda target : target ())
109
108
@patch ('time.sleep' , MagicMock ())
110
109
def test_train_distributed_master (run , tf_server , cluster_spec , distributed_training_env ):
111
- training .train (distributed_training_env )
110
+ training .train (distributed_training_env , MODEL_DIR_CMD_LIST )
112
111
113
112
cluster_spec .assert_called_with ({'worker' : ['host2:2222' ],
114
113
'master' : ['host1:2222' ],
@@ -126,8 +125,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
126
125
'"environment": "cloud", ' \
127
126
'"task": {"index": 0, "type": "master"}}'
128
127
129
- run .assert_called_with ('s3://my/bucket' , 'script_name' ,
130
- distributed_training_env .to_cmd_args (),
128
+ run .assert_called_with ('s3://my/bucket' , 'script_name' , MODEL_DIR_CMD_LIST ,
131
129
{'TF_CONFIG' : tf_config })
132
130
133
131
@@ -140,7 +138,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
140
138
def test_train_distributed_worker (run , tf_server , cluster_spec , distributed_training_env ):
141
139
distributed_training_env .current_host = HOST2
142
140
143
- training .train (distributed_training_env )
141
+ training .train (distributed_training_env , MODEL_DIR_CMD_LIST )
144
142
145
143
cluster_spec .assert_called_with ({'worker' : ['host2:2222' ],
146
144
'master' : ['host1:2222' ],
@@ -158,8 +156,7 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
158
156
'"environment": "cloud", ' \
159
157
'"task": {"index": 0, "type": "worker"}}'
160
158
161
- run .assert_called_with ('s3://my/bucket' , 'script_name' ,
162
- distributed_training_env .to_cmd_args (),
159
+ run .assert_called_with ('s3://my/bucket' , 'script_name' , MODEL_DIR_CMD_LIST ,
163
160
{'TF_CONFIG' : tf_config })
164
161
165
162
@@ -168,9 +165,9 @@ def test_train_distributed_no_ps(run, distributed_training_env):
168
165
distributed_training_env .additional_framework_parameters [
169
166
training .SAGEMAKER_PARAMETER_SERVER_ENABLED ] = False
170
167
distributed_training_env .current_host = HOST2
171
- training .train (distributed_training_env )
168
+ training .train (distributed_training_env , MODEL_DIR_CMD_LIST )
172
169
173
- run .assert_called_with (MODULE_DIR , MODULE_NAME , distributed_training_env . to_cmd_args () ,
170
+ run .assert_called_with (MODULE_DIR , MODULE_NAME , MODEL_DIR_CMD_LIST ,
174
171
distributed_training_env .to_env_vars (), runner = runner .ProcessRunnerType )
175
172
176
173
@@ -251,7 +248,7 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
251
248
training .main ()
252
249
read_hyperparameters .assert_called_once_with ()
253
250
training_env .assert_called_once_with (hyperparameters = {})
254
- train .assert_called_once_with (single_machine_training_env )
251
+ train .assert_called_once_with (single_machine_training_env , MODEL_DIR_CMD_LIST )
255
252
configure_s3_env .assert_called_once ()
256
253
257
254
@@ -276,10 +273,25 @@ def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters,
276
273
@patch ('sagemaker_containers.beta.framework.env.read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR ,
277
274
'_tuning_objective_metric' : 'auc' })
278
275
@patch ('sagemaker_tensorflow_container.s3_utils.configure' )
279
- def test_main_tunning_model_dir (configure_s3_env , read_hyperparameters , training_env ,
280
- set_level , train , logger , single_machine_training_env ):
276
+ def test_main_tuning_model_dir (configure_s3_env , read_hyperparameters , training_env ,
277
+ set_level , train , logger , single_machine_training_env ):
281
278
training_env .return_value = single_machine_training_env
282
279
os .environ ['SAGEMAKER_REGION' ] = REGION
283
280
training .main ()
284
- expected_model_dir = os . path . join (MODEL_DIR , single_machine_training_env .job_name , 'checkpoints' )
281
+ expected_model_dir = '{}/{}/model' . format (MODEL_DIR , single_machine_training_env .job_name )
285
282
configure_s3_env .assert_called_once_with (expected_model_dir , REGION )
283
+
284
+
285
+ @patch ('sagemaker_tensorflow_container.training.logger' )
286
+ @patch ('sagemaker_tensorflow_container.training.train' )
287
+ @patch ('logging.Logger.setLevel' )
288
+ @patch ('sagemaker_containers.beta.framework.training_env' )
289
+ @patch ('sagemaker_containers.beta.framework.env.read_hyperparameters' , return_value = {'model_dir' : '/opt/ml/model' ,
290
+ '_tuning_objective_metric' : 'auc' })
291
+ @patch ('sagemaker_tensorflow_container.s3_utils.configure' )
292
+ def test_main_tuning_mpi_model_dir (configure_s3_env , read_hyperparameters , training_env ,
293
+ set_level , train , logger , single_machine_training_env ):
294
+ training_env .return_value = single_machine_training_env
295
+ os .environ ['SAGEMAKER_REGION' ] = REGION
296
+ training .main ()
297
+ configure_s3_env .assert_called_once_with ('/opt/ml/model' , REGION )
0 commit comments