@@ -14,6 +14,7 @@

import json
import os
+ import subprocess

from mock import MagicMock, patch

import pytest
@@ -43,27 +44,24 @@

@pytest.fixture
def distributed_training_env():
-     env = MagicMock()
-
-     env.module_dir = MODULE_DIR
-     env.module_name = MODULE_NAME
-     env.hyperparameters = {}
-     env.log_level = LOG_LEVEL
-     env.hosts = HOST_LIST
-     env.current_host = CURRENT_HOST
-     env.additional_framework_parameters = {
-         training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
-     }
-
-     return env
+     return MagicMock(module_dir=MODULE_DIR,
+                      user_entry_point=MODULE_NAME,
+                      hyperparameters={},
+                      log_level=LOG_LEVEL,
+                      hosts=HOST_LIST,
+                      current_host=CURRENT_HOST,
+                      to_env_vars=lambda: {},
+                      additional_framework_parameters={
+                          training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
+                      })


@pytest.fixture
def single_machine_training_env():
    env = MagicMock()

    env.module_dir = MODULE_DIR
-     env.module_name = MODULE_NAME
+     env.user_entry_point = MODULE_NAME
    env.hyperparameters = {'model_dir': MODEL_DIR}
    env.log_level = LOG_LEVEL

@@ -76,48 +74,87 @@ def test_is_host_master():
    assert training._is_host_master(HOST_LIST, 'somehost') is False


- @patch('sagemaker_containers.beta.framework.modules.run_module')
+ @patch('sagemaker_containers.beta.framework.entry_point.run')
def test_single_machine(run_module, single_machine_training_env):
    training.train(single_machine_training_env)
-     run_module.assert_called_with(MODULE_DIR, single_machine_training_env.to_cmd_args(),
-                                   single_machine_training_env.to_env_vars(), MODULE_NAME)
+     run_module.assert_called_with(MODULE_DIR, MODULE_NAME,
+                                   single_machine_training_env.to_cmd_args(),
+                                   single_machine_training_env.to_env_vars())


- @patch('sagemaker_tensorflow_container.training._wait_until_master_is_down')
- @patch('sagemaker_tensorflow_container.training._run_worker')
- @patch('sagemaker_tensorflow_container.training._run_ps')
- def test_train_distributed_master(run_ps,
-                                   run_worker,
-                                   wait_until_master_is_down,
-                                   distributed_training_env):
+ @patch('sagemaker_containers.beta.framework.entry_point.run')
+ @patch('time.sleep', MagicMock())
+ def test_train_distributed_master(run, distributed_training_env):
    training.train(distributed_training_env)
-     run_ps.assert_called_with(distributed_training_env)
-     run_worker.assert_called_with(distributed_training_env, install_module=False)
-     wait_until_master_is_down.assert_not_called

-
- @patch('sagemaker_tensorflow_container.training._wait_until_master_is_down')
- @patch('sagemaker_tensorflow_container.training._run_worker')
- @patch('sagemaker_tensorflow_container.training._run_ps')
- def test_train_distributed_worker(run_ps,
-                                   run_worker,
-                                   wait_until_master_is_down,
+     ps_tf_config = '{"cluster": {' \
+                    '"master": ["host1:2222"], ' \
+                    '"ps": ["host1:2223", "host2:2223"], ' \
+                    '"worker": ["host2:2222"]}, ' \
+                    '"environment": "cloud", ' \
+                    '"task": {"index": 0, "type": "ps"}}'
+
+     run.assert_any_call('s3://my/bucket', 'script_name',
+                         distributed_training_env.to_cmd_args(),
+                         {'TF_CONFIG': ps_tf_config})
+
+     master_tf_config = '{"cluster": {' \
+                        '"master": ["host1:2222"], ' \
+                        '"ps": ["host1:2223", "host2:2223"], ' \
+                        '"worker": ["host2:2222"]}, ' \
+                        '"environment": "cloud", ' \
+                        '"task": {"index": 0, "type": "master"}}'
+
+     run.assert_called_with('s3://my/bucket', 'script_name',
+                            distributed_training_env.to_cmd_args(),
+                            {
+                                'TF_CONFIG': master_tf_config})
+
+
+ @patch('subprocess.check_call')
+ @patch('time.sleep', MagicMock())
+ @patch('sagemaker_containers.beta.framework.entry_point.run')
+ def test_train_distributed_worker(run,
+                                   check_call,
                                    distributed_training_env):
    distributed_training_env.current_host = HOST2
+     check_call.side_effect = subprocess.CalledProcessError(returncode=1, cmd=[])
+
    training.train(distributed_training_env)
-     run_ps.assert_called_with(distributed_training_env)
-     run_worker.assert_called_with(distributed_training_env, install_module=False)
-     wait_until_master_is_down.assert_called_with(HOST1)

+     ps_tf_config = '{"cluster": {' \
+                    '"master": ["host1:2222"], ' \
+                    '"ps": ["host1:2223", "host2:2223"], ' \
+                    '"worker": ["host2:2222"]}, ' \
+                    '"environment": "cloud", ' \
+                    '"task": {"index": 1, "type": "ps"}}'

- @patch('sagemaker_containers.beta.framework.modules.run_module')
- def test_train_distributed_no_ps(run_module, distributed_training_env):
+     run.assert_any_call('s3://my/bucket', 'script_name',
+                         distributed_training_env.to_cmd_args(),
+                         {'TF_CONFIG': ps_tf_config})
+
+     master_tf_config = '{"cluster": {' \
+                        '"master": ["host1:2222"], ' \
+                        '"ps": ["host1:2223", "host2:2223"], ' \
+                        '"worker": ["host2:2222"]}, ' \
+                        '"environment": "cloud", ' \
+                        '"task": {"index": 0, "type": "worker"}}'
+
+     run.assert_called_with('s3://my/bucket', 'script_name',
+                            distributed_training_env.to_cmd_args(),
+                            {
+                                'TF_CONFIG': master_tf_config})
+
+
+ @patch('sagemaker_containers.beta.framework.entry_point.run')
+ def test_train_distributed_no_ps(run, distributed_training_env):
    distributed_training_env.additional_framework_parameters[
        training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
    distributed_training_env.current_host = HOST2
    training.train(distributed_training_env)
-     run_module.assert_called_with(MODULE_DIR, distributed_training_env.to_cmd_args(),
-                                   distributed_training_env.to_env_vars(), MODULE_NAME)
+
+     run.assert_called_with(MODULE_DIR, MODULE_NAME, distributed_training_env.to_cmd_args(),
+                            distributed_training_env.to_env_vars())


@patch('sagemaker_tensorflow_container.training._build_tf_config')
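The TF_CONFIG string literals asserted above correspond exactly to a json.dumps(..., sort_keys=True) of the cluster/task structure below. A minimal sketch to make the expected layout explicit; host names and ports are taken from the test literals, and whether training.py serializes the dict in exactly this way is an assumption:

import json

# Expected master TF_CONFIG for the two-host cluster used in the tests
# (host1 is the master; every host also runs a parameter server on port 2223).
master_tf_config = {
    'cluster': {
        'master': ['host1:2222'],
        'ps': ['host1:2223', 'host2:2223'],
        'worker': ['host2:2222'],
    },
    'environment': 'cloud',
    'task': {'index': 0, 'type': 'master'},
}

# sort_keys=True reproduces the exact string literal asserted in
# test_train_distributed_master.
assert json.dumps(master_tf_config, sort_keys=True) == (
    '{"cluster": {"master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"], '
    '"worker": ["host2:2222"]}, "environment": "cloud", '
    '"task": {"index": 0, "type": "master"}}'
)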
@@ -131,61 +168,26 @@ def test_get_env_vars_with_tf_config(build_tf_config, distributed_training_env):
        hosts=HOST_LIST, current_host=CURRENT_HOST, ps_task=True)


- @patch('sagemaker_containers.beta.framework.modules.run_module')
+ @patch('sagemaker_containers.beta.framework.entry_point.run')
@patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
- def test_run_ps(env_vars_with_tf_config, run_module, distributed_training_env):
-     env_vars_with_tf_config.return_value = {}
-     distributed_training_env.to_cmd_args.return_value = CMD_ARGS
+ def test_run_ps(env_vars_with_tf_config, run, distributed_training_env):
    training._run_ps(distributed_training_env)
    env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=True)
-     run_module.assert_called_once_with(distributed_training_env.module_dir,
-                                        CMD_ARGS,
-                                        {},
-                                        distributed_training_env.module_name,
-                                        wait=False)
-

- @patch('sagemaker_containers.beta.framework.modules.write_env_vars')
- @patch('sagemaker_containers.beta.framework.modules.run')
- @patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
- def test_run_worker_no_install(get_env_vars_with_tf_config,
-                                run,
-                                write_env_vars,
-                                distributed_training_env):
-     get_env_vars_with_tf_config.return_value = {}
-     distributed_training_env.to_cmd_args.return_value = CMD_ARGS
-     training._run_worker(distributed_training_env, install_module=False)
-     get_env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=False)
-     write_env_vars.assert_called_once_with({})
-     run.assert_called_once_with(distributed_training_env.module_name,
-                                 CMD_ARGS,
-                                 {})
-
-
- @patch('sagemaker_containers.beta.framework.modules.run_module')
- @patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
- def test_run_worker_install(get_env_vars_with_tf_config,
-                             run_module,
-                             distributed_training_env):
-     get_env_vars_with_tf_config.return_value = {}
-     distributed_training_env.to_cmd_args.return_value = CMD_ARGS
-     training._run_worker(distributed_training_env, install_module=True)
-     get_env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=False)
-     run_module.assert_called_once_with(distributed_training_env.module_dir,
-                                        CMD_ARGS,
-                                        {},
-                                        distributed_training_env.module_name)
+     run.assert_called_once_with(distributed_training_env.module_dir,
+                                 distributed_training_env.user_entry_point,
+                                 distributed_training_env.to_cmd_args(), env_vars_with_tf_config())


def test_build_tf_config():
-     assert training._build_tf_config(HOST_LIST, HOST1) == \
-         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
+     assert training._build_tf_config(HOST_LIST, HOST1) == \
+         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
    assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == \
-         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
-     assert training._build_tf_config(HOST_LIST, HOST2) == \
-         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
+         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
+     assert training._build_tf_config(HOST_LIST, HOST2) == \
+         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
    assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == \
-         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}
+         {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}


def test_build_tf_config_error():
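For reference, a hypothetical sketch of a _build_tf_config helper that would satisfy the assertions in test_build_tf_config and the TF_CONFIG literals above, assuming HOST_LIST is ['host1', 'host2'] with HOST1 ('host1') as the master; the actual implementation in sagemaker_tensorflow_container/training.py may differ:

def _build_tf_config_sketch(hosts, current_host, ps_task=False):
    # First host acts as the master on port 2222; every host also runs a
    # parameter server on port 2223; the remaining hosts run workers on 2222.
    master = hosts[0]
    workers = hosts[1:]

    cluster = {
        'master': ['{}:2222'.format(master)],
        'ps': ['{}:2223'.format(host) for host in hosts],
        'worker': ['{}:2222'.format(host) for host in workers],
    }

    if ps_task:
        task = {'index': hosts.index(current_host), 'type': 'ps'}
    elif current_host == master:
        task = {'index': 0, 'type': 'master'}
    else:
        task = {'index': workers.index(current_host), 'type': 'worker'}

    return {'cluster': cluster, 'environment': 'cloud', 'task': task}

With hosts ['host1', 'host2'] this reproduces the master, worker, and ps task dictionaries asserted in the tests above.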