Commit 1fab499

Update sagemaker containers (#119)
* Update sagemaker containers
1 parent: 5913b17 · commit: 1fab499

File tree: 4 files changed, +97 -101 lines


setup.py

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ def read(fname):
         'Programming Language :: Python :: 3.6',
     ],
 
-    install_requires=['sagemaker-containers>=2.2.6', 'numpy', 'scipy', 'sklearn',
+    install_requires=['sagemaker-containers>=2.3.1', 'numpy', 'scipy', 'sklearn',
                       'pandas', 'Pillow', 'h5py'],
     extras_require={
         'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',
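Note: the floor moves from 2.2.6 to 2.3.1 because the training code below stops calling framework.modules.run_module and starts calling framework.entry_point.run, which the bumped release provides. A minimal, illustrative check (not part of the commit) that the installed library supports the new API:

# Illustrative only, not part of this commit: verify that the installed
# sagemaker-containers release exposes the entry_point API training.py now uses.
import pkg_resources
from sagemaker_containers.beta import framework

print('sagemaker-containers', pkg_resources.get_distribution('sagemaker-containers').version)

# entry_point.run replaces modules.run_module in this commit; it is available
# in the 2.3.x releases that setup.py now requires.
assert hasattr(framework, 'entry_point')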

src/sagemaker_tensorflow_container/training.py

Lines changed: 6 additions & 12 deletions

@@ -99,18 +99,12 @@ def _env_vars_with_tf_config(env, ps_task):
 
 def _run_ps(env):
     env_vars = _env_vars_with_tf_config(env, ps_task=True)
-    return framework.modules.run_module(
-        env.module_dir, env.to_cmd_args(), env_vars, env.module_name, wait=False)
+    framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
 
 
-def _run_worker(env, install_module=False):
+def _run_worker(env):
     env_vars = _env_vars_with_tf_config(env, ps_task=False)
-    if install_module:
-        return framework.modules.run_module(
-            env.module_dir, env.to_cmd_args(), env_vars, env.module_name)
-    else:
-        framework.modules.write_env_vars(env_vars)
-        framework.modules.run(env.module_name, env.to_cmd_args(), env_vars)
+    framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
 
 
 def _wait_until_master_is_down(master):
@@ -139,14 +133,14 @@ def train(env):
         logger.info('Launching parameter server process')
         _run_ps(env)
         logger.info('Launching worker process')
-        _run_worker(env, install_module=False)
+        _run_worker(env)
 
         if not _is_host_master(env.hosts, env.current_host):
             _wait_until_master_is_down(env.hosts[0])
 
     else:
-        framework.modules.run_module(env.module_dir, env.to_cmd_args(),
-                                     env.to_env_vars(), env.module_name)
+        framework.entry_point.run(env.module_dir, env.user_entry_point,
+                                  env.to_cmd_args(), env.to_env_vars())
 
 
 def main():
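The substantive change above is the switch from the modules API to the entry_point API, and the easy thing to miss is the argument order. A side-by-side sketch follows; the old call is reproduced only as a comment, and launch is a hypothetical wrapper used for illustration, since the real code calls entry_point.run directly from _run_ps, _run_worker, and train.

from sagemaker_containers.beta import framework


def launch(env, env_vars):
    # Old API (sagemaker-containers 2.2.x), removed by this commit:
    #   framework.modules.run_module(
    #       env.module_dir, env.to_cmd_args(), env_vars, env.module_name, wait=False)
    #
    # New API: the entry point name moves to the second argument and the
    # environment variables (including TF_CONFIG) move to the last.
    framework.entry_point.run(env.module_dir,        # location of the user module, e.g. an S3 URI
                              env.user_entry_point,  # user script or module to execute
                              env.to_cmd_args(),     # hyperparameters rendered as CLI arguments
                              env_vars)              # process environment, carries TF_CONFIG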

test/unit/test_training.py

Lines changed: 89 additions & 87 deletions

@@ -14,6 +14,7 @@
 
 import json
 import os
+import subprocess
 
 from mock import MagicMock, patch
 import pytest
@@ -43,27 +44,24 @@
 
 @pytest.fixture
 def distributed_training_env():
-    env = MagicMock()
-
-    env.module_dir = MODULE_DIR
-    env.module_name = MODULE_NAME
-    env.hyperparameters = {}
-    env.log_level = LOG_LEVEL
-    env.hosts = HOST_LIST
-    env.current_host = CURRENT_HOST
-    env.additional_framework_parameters = {
-        training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
-    }
-
-    return env
+    return MagicMock(module_dir=MODULE_DIR,
+                     user_entry_point=MODULE_NAME,
+                     hyperparameters={},
+                     log_level=LOG_LEVEL,
+                     hosts=HOST_LIST,
+                     current_host=CURRENT_HOST,
+                     to_env_vars=lambda: {},
+                     additional_framework_parameters={
+                         training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
+                     })
 
 
 @pytest.fixture
 def single_machine_training_env():
     env = MagicMock()
 
     env.module_dir = MODULE_DIR
-    env.module_name = MODULE_NAME
+    env.user_entry_point = MODULE_NAME
    env.hyperparameters = {'model_dir': MODEL_DIR}
     env.log_level = LOG_LEVEL
 
@@ -76,48 +74,87 @@ def test_is_host_master():
     assert training._is_host_master(HOST_LIST, 'somehost') is False
 
 
-@patch('sagemaker_containers.beta.framework.modules.run_module')
+@patch('sagemaker_containers.beta.framework.entry_point.run')
 def test_single_machine(run_module, single_machine_training_env):
     training.train(single_machine_training_env)
-    run_module.assert_called_with(MODULE_DIR, single_machine_training_env.to_cmd_args(),
-                                  single_machine_training_env.to_env_vars(), MODULE_NAME)
+    run_module.assert_called_with(MODULE_DIR, MODULE_NAME,
+                                  single_machine_training_env.to_cmd_args(),
+                                  single_machine_training_env.to_env_vars())
 
 
-@patch('sagemaker_tensorflow_container.training._wait_until_master_is_down')
-@patch('sagemaker_tensorflow_container.training._run_worker')
-@patch('sagemaker_tensorflow_container.training._run_ps')
-def test_train_distributed_master(run_ps,
-                                  run_worker,
-                                  wait_until_master_is_down,
-                                  distributed_training_env):
+@patch('sagemaker_containers.beta.framework.entry_point.run')
+@patch('time.sleep', MagicMock())
+def test_train_distributed_master(run, distributed_training_env):
     training.train(distributed_training_env)
-    run_ps.assert_called_with(distributed_training_env)
-    run_worker.assert_called_with(distributed_training_env, install_module=False)
-    wait_until_master_is_down.assert_not_called
 
-
-@patch('sagemaker_tensorflow_container.training._wait_until_master_is_down')
-@patch('sagemaker_tensorflow_container.training._run_worker')
-@patch('sagemaker_tensorflow_container.training._run_ps')
-def test_train_distributed_worker(run_ps,
-                                  run_worker,
-                                  wait_until_master_is_down,
+    ps_tf_config = '{"cluster": {' \
+                   '"master": ["host1:2222"], ' \
+                   '"ps": ["host1:2223", "host2:2223"], ' \
+                   '"worker": ["host2:2222"]}, ' \
+                   '"environment": "cloud", ' \
+                   '"task": {"index": 0, "type": "ps"}}'
+
+    run.assert_any_call('s3://my/bucket', 'script_name',
+                        distributed_training_env.to_cmd_args(),
+                        {'TF_CONFIG': ps_tf_config})
+
+    master_tf_config = '{"cluster": {' \
+                       '"master": ["host1:2222"], ' \
+                       '"ps": ["host1:2223", "host2:2223"], ' \
+                       '"worker": ["host2:2222"]}, ' \
+                       '"environment": "cloud", ' \
+                       '"task": {"index": 0, "type": "master"}}'
+
+    run.assert_called_with('s3://my/bucket', 'script_name',
+                           distributed_training_env.to_cmd_args(),
+                           {
+                               'TF_CONFIG': master_tf_config})
+
+
+@patch('subprocess.check_call')
+@patch('time.sleep', MagicMock())
+@patch('sagemaker_containers.beta.framework.entry_point.run')
+def test_train_distributed_worker(run,
+                                  check_call,
                                   distributed_training_env):
     distributed_training_env.current_host = HOST2
+    check_call.side_effect = subprocess.CalledProcessError(returncode=1, cmd=[])
+
     training.train(distributed_training_env)
-    run_ps.assert_called_with(distributed_training_env)
-    run_worker.assert_called_with(distributed_training_env, install_module=False)
-    wait_until_master_is_down.assert_called_with(HOST1)
 
+    ps_tf_config = '{"cluster": {' \
+                   '"master": ["host1:2222"], ' \
+                   '"ps": ["host1:2223", "host2:2223"], ' \
+                   '"worker": ["host2:2222"]}, ' \
+                   '"environment": "cloud", ' \
+                   '"task": {"index": 1, "type": "ps"}}'
 
-@patch('sagemaker_containers.beta.framework.modules.run_module')
-def test_train_distributed_no_ps(run_module, distributed_training_env):
+    run.assert_any_call('s3://my/bucket', 'script_name',
+                        distributed_training_env.to_cmd_args(),
+                        {'TF_CONFIG': ps_tf_config})
+
+    master_tf_config = '{"cluster": {' \
+                       '"master": ["host1:2222"], ' \
+                       '"ps": ["host1:2223", "host2:2223"], ' \
+                       '"worker": ["host2:2222"]}, ' \
+                       '"environment": "cloud", ' \
+                       '"task": {"index": 0, "type": "worker"}}'
+
+    run.assert_called_with('s3://my/bucket', 'script_name',
+                           distributed_training_env.to_cmd_args(),
+                           {
+                               'TF_CONFIG': master_tf_config})
+
+
+@patch('sagemaker_containers.beta.framework.entry_point.run')
+def test_train_distributed_no_ps(run, distributed_training_env):
     distributed_training_env.additional_framework_parameters[
         training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
     distributed_training_env.current_host = HOST2
     training.train(distributed_training_env)
-    run_module.assert_called_with(MODULE_DIR, distributed_training_env.to_cmd_args(),
-                                  distributed_training_env.to_env_vars(), MODULE_NAME)
+
+    run.assert_called_with(MODULE_DIR, MODULE_NAME, distributed_training_env.to_cmd_args(),
+                           distributed_training_env.to_env_vars())
 
 
 @patch('sagemaker_tensorflow_container.training._build_tf_config')
@@ -131,61 +168,26 @@ def test_get_env_vars_with_tf_config(build_tf_config, distributed_training_env):
         hosts=HOST_LIST, current_host=CURRENT_HOST, ps_task=True)
 
 
-@patch('sagemaker_containers.beta.framework.modules.run_module')
+@patch('sagemaker_containers.beta.framework.entry_point.run')
 @patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
-def test_run_ps(env_vars_with_tf_config, run_module, distributed_training_env):
-    env_vars_with_tf_config.return_value = {}
-    distributed_training_env.to_cmd_args.return_value = CMD_ARGS
+def test_run_ps(env_vars_with_tf_config, run, distributed_training_env):
     training._run_ps(distributed_training_env)
     env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=True)
-    run_module.assert_called_once_with(distributed_training_env.module_dir,
-                                       CMD_ARGS,
-                                       {},
-                                       distributed_training_env.module_name,
-                                       wait=False)
-
 
-@patch('sagemaker_containers.beta.framework.modules.write_env_vars')
-@patch('sagemaker_containers.beta.framework.modules.run')
-@patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
-def test_run_worker_no_install(get_env_vars_with_tf_config,
-                               run,
-                               write_env_vars,
-                               distributed_training_env):
-    get_env_vars_with_tf_config.return_value = {}
-    distributed_training_env.to_cmd_args.return_value = CMD_ARGS
-    training._run_worker(distributed_training_env, install_module=False)
-    get_env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=False)
-    write_env_vars.assert_called_once_with({})
-    run.assert_called_once_with(distributed_training_env.module_name,
-                                CMD_ARGS,
-                                {})
-
-
-@patch('sagemaker_containers.beta.framework.modules.run_module')
-@patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
-def test_run_worker_install(get_env_vars_with_tf_config,
-                            run_module,
-                            distributed_training_env):
-    get_env_vars_with_tf_config.return_value = {}
-    distributed_training_env.to_cmd_args.return_value = CMD_ARGS
-    training._run_worker(distributed_training_env, install_module=True)
-    get_env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=False)
-    run_module.assert_called_once_with(distributed_training_env.module_dir,
-                                       CMD_ARGS,
-                                       {},
-                                       distributed_training_env.module_name)
+    run.assert_called_once_with(distributed_training_env.module_dir,
+                                distributed_training_env.user_entry_point,
+                                distributed_training_env.to_cmd_args(), env_vars_with_tf_config())
 
 
 def test_build_tf_config():
-    assert training._build_tf_config(HOST_LIST, HOST1) ==\
-        {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
+    assert training._build_tf_config(HOST_LIST, HOST1) == \
+           {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
     assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == \
-        {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
-    assert training._build_tf_config(HOST_LIST, HOST2) ==\
-        {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
+           {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
+    assert training._build_tf_config(HOST_LIST, HOST2) == \
+           {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
     assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == \
-        {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}
+           {'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}
 
 
 def test_build_tf_config_error():
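The rewritten distributed tests now assert the exact TF_CONFIG JSON handed to each launched process instead of only checking that _run_ps and _run_worker were called. For reference, a small sketch that rebuilds those strings for the two-host cluster the tests use; the values are copied from the assertions above, and the production builder is training._build_tf_config.

# Illustrative reconstruction of the TF_CONFIG the tests expect for a two-host
# cluster (master on host1, worker on host2, a parameter server on each host).
import json

cluster = {
    'master': ['host1:2222'],
    'ps': ['host1:2223', 'host2:2223'],
    'worker': ['host2:2222'],
}

def tf_config_for(task_type, index):
    # Each launched process receives the same cluster map plus its own task entry.
    return json.dumps({'cluster': cluster,
                       'environment': 'cloud',
                       'task': {'index': index, 'type': task_type}})

print(tf_config_for('master', 0))  # what the master process on host1 receives
print(tf_config_for('ps', 1))      # what the parameter server on host2 receives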

tox.ini

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ passenv =
 # Can be used to specify which tests to run, e.g.: tox -- -s
 commands =
     coverage run --rcfile .coveragerc_{envname} --source sagemaker_tensorflow_container -m py.test {posargs}
-    {env:IGNORE_COVERAGE:} coverage report --fail-under=90 --include *sagemaker_tensorflow_container* --omit */tensorflow/tensorflow_serving/*
+    {env:IGNORE_COVERAGE:} coverage report --fail-under=90 --include *sagemaker_tensorflow_container* --show-missing
 deps = .[test]
 
 [testenv:flake8]
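The changed tox command swaps the --omit */tensorflow/tensorflow_serving/* filter for --show-missing, so the coverage report now lists uncovered line numbers; the {env:IGNORE_COVERAGE:} prefix still expands to the value of IGNORE_COVERAGE, or to nothing when it is unset. Roughly what the updated step runs, sketched as a standalone script; tox normally drives this, and coverage data from the preceding coverage run must already exist.

# Illustrative stand-in for the second tox command; run only after
# `coverage run ...` has produced coverage data for the package.
import subprocess

subprocess.check_call([
    'coverage', 'report',
    '--fail-under=90',                                # fail if total coverage is below 90%
    '--include', '*sagemaker_tensorflow_container*',  # report only on the package itself
    '--show-missing',                                 # new flag: print uncovered line numbers
])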
