@@ -70,6 +70,7 @@ def simple_training_env():
70
70
env .hosts = CURRENT_HOST
71
71
env .current_host = CURRENT_HOST
72
72
env .to_env_vars = lambda : {}
73
+ env .job_name = 'test-training-job'
73
74
return env
74
75
75
76
@@ -252,3 +253,33 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
252
253
training_env .assert_called_once_with (hyperparameters = {})
253
254
train .assert_called_once_with (single_machine_training_env )
254
255
configure_s3_env .assert_called_once ()
256
+
257
+
258
+ @patch ('sagemaker_tensorflow_container.training.logger' )
259
+ @patch ('sagemaker_tensorflow_container.training.train' )
260
+ @patch ('logging.Logger.setLevel' )
261
+ @patch ('sagemaker_containers.beta.framework.training_env' )
262
+ @patch ('sagemaker_containers.beta.framework.env.read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR })
263
+ @patch ('sagemaker_tensorflow_container.s3_utils.configure' )
264
+ def test_main_simple_training_model_dir (configure_s3_env , read_hyperparameters , training_env ,
265
+ set_level , train , logger , single_machine_training_env ):
266
+ training_env .return_value = single_machine_training_env
267
+ os .environ ['SAGEMAKER_REGION' ] = REGION
268
+ training .main ()
269
+ configure_s3_env .assert_called_once_with (MODEL_DIR , REGION )
270
+
271
+
272
+ @patch ('sagemaker_tensorflow_container.training.logger' )
273
+ @patch ('sagemaker_tensorflow_container.training.train' )
274
+ @patch ('logging.Logger.setLevel' )
275
+ @patch ('sagemaker_containers.beta.framework.training_env' )
276
+ @patch ('sagemaker_containers.beta.framework.env.read_hyperparameters' , return_value = {'model_dir' : MODEL_DIR ,
277
+ '_tuning_objective_metric' : 'auc' })
278
+ @patch ('sagemaker_tensorflow_container.s3_utils.configure' )
279
+ def test_main_tunning_model_dir (configure_s3_env , read_hyperparameters , training_env ,
280
+ set_level , train , logger , single_machine_training_env ):
281
+ training_env .return_value = single_machine_training_env
282
+ os .environ ['SAGEMAKER_REGION' ] = REGION
283
+ training .main ()
284
+ expected_model_dir = os .path .join (MODEL_DIR , single_machine_training_env .job_name , 'checkpoints' )
285
+ configure_s3_env .assert_called_once_with (expected_model_dir , REGION )
0 commit comments