@@ -96,7 +96,8 @@ def test_train_distributed_master(run, distributed_training_env):
96
96
97
97
run .assert_any_call ('s3://my/bucket' , 'script_name' ,
98
98
distributed_training_env .to_cmd_args (),
99
- {'TF_CONFIG' : ps_tf_config })
99
+ {'TF_CONFIG' : ps_tf_config },
100
+ wait = False )
100
101
101
102
master_tf_config = '{"cluster": {' \
102
103
'"master": ["host1:2222"], ' \
@@ -107,8 +108,7 @@ def test_train_distributed_master(run, distributed_training_env):
107
108
108
109
run .assert_called_with ('s3://my/bucket' , 'script_name' ,
109
110
distributed_training_env .to_cmd_args (),
110
- {
111
- 'TF_CONFIG' : master_tf_config })
111
+ {'TF_CONFIG' : master_tf_config })
112
112
113
113
114
114
@patch ('subprocess.check_call' )
@@ -131,7 +131,7 @@ def test_train_distributed_worker(run,
131
131
132
132
run .assert_any_call ('s3://my/bucket' , 'script_name' ,
133
133
distributed_training_env .to_cmd_args (),
134
- {'TF_CONFIG' : ps_tf_config })
134
+ {'TF_CONFIG' : ps_tf_config }, wait = False )
135
135
136
136
master_tf_config = '{"cluster": {' \
137
137
'"master": ["host1:2222"], ' \
@@ -176,18 +176,30 @@ def test_run_ps(env_vars_with_tf_config, run, distributed_training_env):
176
176
177
177
run .assert_called_once_with (distributed_training_env .module_dir ,
178
178
distributed_training_env .user_entry_point ,
179
- distributed_training_env .to_cmd_args (), env_vars_with_tf_config ())
179
+ distributed_training_env .to_cmd_args (), env_vars_with_tf_config (),
180
+ wait = False )
180
181
181
182
182
183
def test_build_tf_config ():
183
- assert training ._build_tf_config (HOST_LIST , HOST1 ) == \
184
- {'cluster' : CLUSTER_WITH_PS , 'environment' : 'cloud' , 'task' : MASTER_TASK }
185
- assert training ._build_tf_config (HOST_LIST , HOST1 , ps_task = True ) == \
186
- {'cluster' : CLUSTER_WITH_PS , 'environment' : 'cloud' , 'task' : PS_TASK_1 }
187
- assert training ._build_tf_config (HOST_LIST , HOST2 ) == \
188
- {'cluster' : CLUSTER_WITH_PS , 'environment' : 'cloud' , 'task' : WORKER_TASK }
189
- assert training ._build_tf_config (HOST_LIST , HOST2 , ps_task = True ) == \
190
- {'cluster' : CLUSTER_WITH_PS , 'environment' : 'cloud' , 'task' : PS_TASK_2 }
184
+ assert training ._build_tf_config (HOST_LIST , HOST1 ) == {
185
+ 'cluster' : CLUSTER_WITH_PS ,
186
+ 'environment' : 'cloud' ,
187
+ 'task' : MASTER_TASK
188
+ }
189
+ assert training ._build_tf_config (HOST_LIST , HOST1 , ps_task = True ) == {
190
+ 'cluster' : CLUSTER_WITH_PS ,
191
+ 'environment' : 'cloud' ,
192
+ 'task' : PS_TASK_1
193
+ }
194
+ assert training ._build_tf_config (HOST_LIST , HOST2 ) == {
195
+ 'cluster' : CLUSTER_WITH_PS ,
196
+ 'environment' : 'cloud' ,
197
+ 'task' : WORKER_TASK
198
+ }
199
+ assert training ._build_tf_config (HOST_LIST , HOST2 , ps_task = True ) == {
200
+ 'cluster' : CLUSTER_WITH_PS ,
201
+ 'environment' : 'cloud' ,
202
+ 'task' : PS_TASK_2 }
191
203
192
204
193
205
def test_build_tf_config_error ():
0 commit comments