@@ -320,6 +320,7 @@ def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, train
320
320
== == == =
321
321
import json
322
322
import os
323
+ import subprocess
323
324
324
325
>> >> >> > Add distributed training support (#98)
325
326
from mock import MagicMock , patch
@@ -350,27 +351,24 @@ def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, train
350
351
351
352
@pytest .fixture
352
353
def distributed_training_env ():
353
- env = MagicMock ()
354
-
355
- env .module_dir = MODULE_DIR
356
- env .module_name = MODULE_NAME
357
- env .hyperparameters = {}
358
- env .log_level = LOG_LEVEL
359
- env .hosts = HOST_LIST
360
- env .current_host = CURRENT_HOST
361
- env .additional_framework_parameters = {
362
- training .SAGEMAKER_PARAMETER_SERVER_ENABLED : True
363
- }
364
-
365
- return env
354
+ return MagicMock (module_dir = MODULE_DIR ,
355
+ user_entry_point = MODULE_NAME ,
356
+ hyperparameters = {},
357
+ log_level = LOG_LEVEL ,
358
+ hosts = HOST_LIST ,
359
+ current_host = CURRENT_HOST ,
360
+ to_env_vars = lambda : {},
361
+ additional_framework_parameters = {
362
+ training .SAGEMAKER_PARAMETER_SERVER_ENABLED : True
363
+ })
366
364
367
365
368
366
@pytest .fixture
369
367
def single_machine_training_env ():
370
368
env = MagicMock ()
371
369
372
370
env .module_dir = MODULE_DIR
373
- env .module_name = MODULE_NAME
371
+ env .user_entry_point = MODULE_NAME
374
372
env .hyperparameters = {'model_dir ': MODEL_DIR }
375
373
env .log_level = LOG_LEVEL
376
374
@@ -383,51 +381,95 @@ def test_is_host_master():
383
381
assert training ._is_host_master (HOST_LIST , 'somehost ') is False
384
382
385
383
386
- @patch ('sagemaker_containers .beta .framework .modules . run_module ')
384
+ @patch ('sagemaker_containers .beta .framework .entry_point . run ')
387
385
def test_single_machine (run_module , single_machine_training_env ):
388
386
training .train (single_machine_training_env )
387
+ << << << < HEAD
389
388
run_module .assert_called_with (MODULE_DIR , single_machine_training_env .to_cmd_args (),
390
389
single_machine_training_env .to_env_vars (), MODULE_NAME )
391
390
<< << << < HEAD
392
391
>> >> >> > Scriptmode single machine training implementation (#78)
393
392
== == == =
393
+ == == == =
394
+ run_module .assert_called_with (MODULE_DIR , MODULE_NAME ,
395
+ single_machine_training_env .to_cmd_args (),
396
+ single_machine_training_env .to_env_vars ())
397
+ >> >> >> > Update sagemaker containers (#119)
394
398
395
399
396
- @patch ('sagemaker_tensorflow_container .training ._wait_until_master_is_down ')
397
- @patch ('sagemaker_tensorflow_container .training ._run_worker ')
398
- @patch ('sagemaker_tensorflow_container .training ._run_ps ')
399
- def test_train_distributed_master (run_ps ,
400
- run_worker ,
401
- wait_until_master_is_down ,
402
- distributed_training_env ):
400
+ @patch ('sagemaker_containers .beta .framework .entry_point .run ')
401
+ @patch ('time .sleep ', MagicMock ())
402
+ def test_train_distributed_master (run , distributed_training_env ):
403
403
training .train (distributed_training_env )
404
- run_ps .assert_called_with (distributed_training_env )
405
- run_worker .assert_called_with (distributed_training_env , install_module = False )
406
- wait_until_master_is_down .assert_not_called
407
404
405
+ ps_tf_config = '{"cluster ": {' \
406
+ '"master ": ["host1 :2222 "], ' \
407
+ '"ps ": ["host1 :2223 ", "host2 :2223 "], ' \
408
+ '"worker ": ["host2 :2222 "]}, ' \
409
+ '"environment ": "cloud ", ' \
410
+ '"task ": {"index ": 0 , "type": "ps "}}'
408
411
409
- @patch ('sagemaker_tensorflow_container .training ._wait_until_master_is_down ')
410
- @patch ('sagemaker_tensorflow_container .training ._run_worker ')
411
- @patch ('sagemaker_tensorflow_container .training ._run_ps ')
412
- def test_train_distributed_worker (run_ps ,
413
- run_worker ,
414
- wait_until_master_is_down ,
412
+ run .assert_any_call ('s3 :// my / bucket ', 'script_name ',
413
+ distributed_training_env .to_cmd_args (),
414
+ {'TF_CONFIG ': ps_tf_config })
415
+
416
+ master_tf_config = '{"cluster ": {' \
417
+ '"master ": ["host1 :2222 "], ' \
418
+ '"ps ": ["host1 :2223 ", "host2 :2223 "], ' \
419
+ '"worker ": ["host2 :2222 "]}, ' \
420
+ '"environment ": "cloud ", ' \
421
+ '"task ": {"index ": 0 , "type": "master "}}'
422
+
423
+ run .assert_called_with ('s3 :// my / bucket ', 'script_name ',
424
+ distributed_training_env .to_cmd_args (),
425
+ {
426
+ 'TF_CONFIG ': master_tf_config })
427
+
428
+
429
+ @patch ('subprocess .check_call ')
430
+ @patch ('time .sleep ', MagicMock ())
431
+ @patch ('sagemaker_containers .beta .framework .entry_point .run ')
432
+ def test_train_distributed_worker (run ,
433
+ check_call ,
415
434
distributed_training_env ):
416
435
distributed_training_env .current_host = HOST2
436
+ check_call .side_effect = subprocess .CalledProcessError (returncode = 1 , cmd = [])
437
+
417
438
training .train (distributed_training_env )
418
- run_ps .assert_called_with (distributed_training_env )
419
- run_worker .assert_called_with (distributed_training_env , install_module = False )
420
- wait_until_master_is_down .assert_called_with (HOST1 )
421
439
440
+ ps_tf_config = '{"cluster ": {' \
441
+ '"master ": ["host1 :2222 "], ' \
442
+ '"ps ": ["host1 :2223 ", "host2 :2223 "], ' \
443
+ '"worker ": ["host2 :2222 "]}, ' \
444
+ '"environment ": "cloud ", ' \
445
+ '"task ": {"index ": 1 , "type": "ps "}}'
446
+
447
+ run .assert_any_call ('s3 :// my / bucket ', 'script_name ',
448
+ distributed_training_env .to_cmd_args (),
449
+ {'TF_CONFIG ': ps_tf_config })
450
+
451
+ master_tf_config = '{"cluster ": {' \
452
+ '"master ": ["host1 :2222 "], ' \
453
+ '"ps ": ["host1 :2223 ", "host2 :2223 "], ' \
454
+ '"worker ": ["host2 :2222 "]}, ' \
455
+ '"environment ": "cloud ", ' \
456
+ '"task ": {"index ": 0 , "type": "worker "}}'
422
457
423
- @patch ('sagemaker_containers .beta .framework .modules .run_module ')
424
- def test_train_distributed_no_ps (run_module , distributed_training_env ):
458
+ run .assert_called_with ('s3 :// my / bucket ', 'script_name ',
459
+ distributed_training_env .to_cmd_args (),
460
+ {
461
+ 'TF_CONFIG ': master_tf_config })
462
+
463
+
464
+ @patch ('sagemaker_containers .beta .framework .entry_point .run ')
465
+ def test_train_distributed_no_ps (run , distributed_training_env ):
425
466
distributed_training_env .additional_framework_parameters [
426
467
training .SAGEMAKER_PARAMETER_SERVER_ENABLED ] = False
427
468
distributed_training_env .current_host = HOST2
428
469
training .train (distributed_training_env )
429
- run_module .assert_called_with (MODULE_DIR , distributed_training_env .to_cmd_args (),
430
- distributed_training_env .to_env_vars (), MODULE_NAME )
470
+
471
+ run .assert_called_with (MODULE_DIR , MODULE_NAME , distributed_training_env .to_cmd_args (),
472
+ distributed_training_env .to_env_vars ())
431
473
432
474
433
475
@patch ('sagemaker_tensorflow_container .training ._build_tf_config ')
@@ -441,61 +483,26 @@ def test_get_env_vars_with_tf_config(build_tf_config, distributed_training_env):
441
483
hosts = HOST_LIST , current_host = CURRENT_HOST , ps_task = True )
442
484
443
485
444
- @patch ('sagemaker_containers .beta .framework .modules . run_module ')
486
+ @patch ('sagemaker_containers .beta .framework .entry_point . run ')
445
487
@patch ('sagemaker_tensorflow_container .training ._env_vars_with_tf_config ')
446
- def test_run_ps (env_vars_with_tf_config , run_module , distributed_training_env ):
447
- env_vars_with_tf_config .return_value = {}
448
- distributed_training_env .to_cmd_args .return_value = CMD_ARGS
488
+ def test_run_ps (env_vars_with_tf_config , run , distributed_training_env ):
449
489
training ._run_ps (distributed_training_env )
450
490
env_vars_with_tf_config .assert_called_once_with (distributed_training_env , ps_task = True )
451
- run_module .assert_called_once_with (distributed_training_env .module_dir ,
452
- CMD_ARGS ,
453
- {},
454
- distributed_training_env .module_name ,
455
- wait = False )
456
-
457
491
458
- @patch ('sagemaker_containers .beta .framework .modules .write_env_vars ')
459
- @patch ('sagemaker_containers .beta .framework .modules .run ')
460
- @patch ('sagemaker_tensorflow_container .training ._env_vars_with_tf_config ')
461
- def test_run_worker_no_install (get_env_vars_with_tf_config ,
462
- run ,
463
- write_env_vars ,
464
- distributed_training_env ):
465
- get_env_vars_with_tf_config .return_value = {}
466
- distributed_training_env .to_cmd_args .return_value = CMD_ARGS
467
- training ._run_worker (distributed_training_env , install_module = False )
468
- get_env_vars_with_tf_config .assert_called_once_with (distributed_training_env , ps_task = False )
469
- write_env_vars .assert_called_once_with ({})
470
- run .assert_called_once_with (distributed_training_env .module_name ,
471
- CMD_ARGS ,
472
- {})
473
-
474
-
475
- @patch ('sagemaker_containers .beta .framework .modules .run_module ')
476
- @patch ('sagemaker_tensorflow_container .training ._env_vars_with_tf_config ')
477
- def test_run_worker_install (get_env_vars_with_tf_config ,
478
- run_module ,
479
- distributed_training_env ):
480
- get_env_vars_with_tf_config .return_value = {}
481
- distributed_training_env .to_cmd_args .return_value = CMD_ARGS
482
- training ._run_worker (distributed_training_env , install_module = True )
483
- get_env_vars_with_tf_config .assert_called_once_with (distributed_training_env , ps_task = False )
484
- run_module .assert_called_once_with (distributed_training_env .module_dir ,
485
- CMD_ARGS ,
486
- {},
487
- distributed_training_env .module_name )
492
+ run .assert_called_once_with (distributed_training_env .module_dir ,
493
+ distributed_training_env .user_entry_point ,
494
+ distributed_training_env .to_cmd_args (), env_vars_with_tf_config ())
488
495
489
496
490
497
def test_build_tf_config ():
491
- assert training ._build_tf_config (HOST_LIST , HOST1 ) == \
492
- {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': MASTER_TASK }
498
+ assert training ._build_tf_config (HOST_LIST , HOST1 ) == \
499
+ {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': MASTER_TASK }
493
500
assert training ._build_tf_config (HOST_LIST , HOST1 , ps_task = True ) == \
494
- {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': PS_TASK_1 }
495
- assert training ._build_tf_config (HOST_LIST , HOST2 ) == \
496
- {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': WORKER_TASK }
501
+ {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': PS_TASK_1 }
502
+ assert training ._build_tf_config (HOST_LIST , HOST2 ) == \
503
+ {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': WORKER_TASK }
497
504
assert training ._build_tf_config (HOST_LIST , HOST2 , ps_task = True ) == \
498
- {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': PS_TASK_2 }
505
+ {'cluster ': CLUSTER_WITH_PS , 'environment ': 'cloud ', 'task ': PS_TASK_2 }
499
506
500
507
501
508
def test_build_tf_config_error ():
0 commit comments