Skip to content

Commit c4abcae

Browse files
authored
Set parameter process waiting to False (#120)
* Add wait=False to _run_ps
1 parent 1fab499 commit c4abcae

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def _env_vars_with_tf_config(env, ps_task):
9999

100100
def _run_ps(env):
101101
env_vars = _env_vars_with_tf_config(env, ps_task=True)
102-
framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
102+
framework.entry_point.run(env.module_dir, env.user_entry_point,
103+
env.to_cmd_args(), env_vars, wait=False)
103104

104105

105106
def _run_worker(env):

test/unit/test_training.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def test_train_distributed_master(run, distributed_training_env):
9696

9797
run.assert_any_call('s3://my/bucket', 'script_name',
9898
distributed_training_env.to_cmd_args(),
99-
{'TF_CONFIG': ps_tf_config})
99+
{'TF_CONFIG': ps_tf_config},
100+
wait=False)
100101

101102
master_tf_config = '{"cluster": {' \
102103
'"master": ["host1:2222"], ' \
@@ -107,8 +108,7 @@ def test_train_distributed_master(run, distributed_training_env):
107108

108109
run.assert_called_with('s3://my/bucket', 'script_name',
109110
distributed_training_env.to_cmd_args(),
110-
{
111-
'TF_CONFIG': master_tf_config})
111+
{'TF_CONFIG': master_tf_config})
112112

113113

114114
@patch('subprocess.check_call')
@@ -131,7 +131,7 @@ def test_train_distributed_worker(run,
131131

132132
run.assert_any_call('s3://my/bucket', 'script_name',
133133
distributed_training_env.to_cmd_args(),
134-
{'TF_CONFIG': ps_tf_config})
134+
{'TF_CONFIG': ps_tf_config}, wait=False)
135135

136136
master_tf_config = '{"cluster": {' \
137137
'"master": ["host1:2222"], ' \
@@ -176,18 +176,30 @@ def test_run_ps(env_vars_with_tf_config, run, distributed_training_env):
176176

177177
run.assert_called_once_with(distributed_training_env.module_dir,
178178
distributed_training_env.user_entry_point,
179-
distributed_training_env.to_cmd_args(), env_vars_with_tf_config())
179+
distributed_training_env.to_cmd_args(), env_vars_with_tf_config(),
180+
wait=False)
180181

181182

182183
def test_build_tf_config():
183-
assert training._build_tf_config(HOST_LIST, HOST1) == \
184-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
185-
assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == \
186-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
187-
assert training._build_tf_config(HOST_LIST, HOST2) == \
188-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
189-
assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == \
190-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}
184+
assert training._build_tf_config(HOST_LIST, HOST1) == {
185+
'cluster': CLUSTER_WITH_PS,
186+
'environment': 'cloud',
187+
'task': MASTER_TASK
188+
}
189+
assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == {
190+
'cluster': CLUSTER_WITH_PS,
191+
'environment': 'cloud',
192+
'task': PS_TASK_1
193+
}
194+
assert training._build_tf_config(HOST_LIST, HOST2) == {
195+
'cluster': CLUSTER_WITH_PS,
196+
'environment': 'cloud',
197+
'task': WORKER_TASK
198+
}
199+
assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == {
200+
'cluster': CLUSTER_WITH_PS,
201+
'environment': 'cloud',
202+
'task': PS_TASK_2}
191203

192204

193205
def test_build_tf_config_error():

0 commit comments

Comments (0)