Skip to content

Commit 81a139d

Browse files
mvsusp and Eliza Zhang
authored and committed
Update sagemaker containers (aws#119)
* Update sagemaker containers
1 parent df8043c commit 81a139d

File tree

4 files changed

+115
-87
lines changed

4 files changed

+115
-87
lines changed

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,15 @@ def read_version():
5454
'Programming Language :: Python :: 3.6',
5555
],
5656

57+
<<<<<<< HEAD
5758
<<<<<<< HEAD
5859
install_requires=['sagemaker-containers>=2.4.6', 'numpy', 'scipy', 'sklearn',
5960
=======
6061
install_requires=['sagemaker-containers>=2.2.6', 'numpy', 'scipy', 'sklearn',
6162
>>>>>>> Add distributed training support (#98)
63+
=======
64+
install_requires=['sagemaker-containers>=2.3.1', 'numpy', 'scipy', 'sklearn',
65+
>>>>>>> Update sagemaker containers (#119)
6266
'pandas', 'Pillow', 'h5py'],
6367
extras_require={
6468
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',

src/sagemaker_tensorflow_container/training.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,19 +194,22 @@ def _env_vars_with_tf_config(env, ps_task):
194194

195195
def _run_ps(env):
196196
env_vars = _env_vars_with_tf_config(env, ps_task=True)
197-
return framework.modules.run_module(
198-
env.module_dir, env.to_cmd_args(), env_vars, env.module_name, wait=False)
197+
framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
199198

200199

201-
def _run_worker(env, install_module=False):
200+
def _run_worker(env):
202201
env_vars = _env_vars_with_tf_config(env, ps_task=False)
202+
<<<<<<< HEAD
203203
if install_module:
204204
return framework.modules.run_module(
205205
env.module_dir, env.to_cmd_args(), env_vars, env.module_name)
206206
else:
207207
framework.modules.write_env_vars(env_vars)
208208
framework.modules.run(env.module_name, env.to_cmd_args(), env_vars)
209209
>>>>>>> Add distributed training support (#98)
210+
=======
211+
framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
212+
>>>>>>> Update sagemaker containers (#119)
210213

211214

212215
def _wait_until_master_is_down(master):
@@ -248,13 +251,18 @@ def train(env):
248251
logger.info('Launching parameter server process')
249252
_run_ps(env)
250253
logger.info('Launching worker process')
254+
<<<<<<< HEAD
251255
_run_worker(env, install_module=False)
252256
>>>>>>> Add distributed training support (#98)
257+
=======
258+
_run_worker(env)
259+
>>>>>>> Update sagemaker containers (#119)
253260

254261
if not _is_host_master(env.hosts, env.current_host):
255262
_wait_until_master_is_down(env.hosts[0])
256263

257264
else:
265+
<<<<<<< HEAD
258266
<<<<<<< HEAD
259267

260268
mpi_enabled = env.additional_framework_parameters.get('sagemaker_mpi_enabled')
@@ -304,6 +312,10 @@ def _model_dir_with_training_job(model_dir, job_name):
304312
framework.modules.run_module(env.module_dir, env.to_cmd_args(),
305313
env.to_env_vars(), env.module_name)
306314
>>>>>>> Add distributed training support (#98)
315+
=======
316+
framework.entry_point.run(env.module_dir, env.user_entry_point,
317+
env.to_cmd_args(), env.to_env_vars())
318+
>>>>>>> Update sagemaker containers (#119)
307319

308320

309321
def main():

test/unit/test_training.py

Lines changed: 91 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, train
320320
=======
321321
import json
322322
import os
323+
import subprocess
323324

324325
>>>>>>> Add distributed training support (#98)
325326
from mock import MagicMock, patch
@@ -350,27 +351,24 @@ def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, train
350351

351352
@pytest.fixture
352353
def distributed_training_env():
353-
env = MagicMock()
354-
355-
env.module_dir = MODULE_DIR
356-
env.module_name = MODULE_NAME
357-
env.hyperparameters = {}
358-
env.log_level = LOG_LEVEL
359-
env.hosts = HOST_LIST
360-
env.current_host = CURRENT_HOST
361-
env.additional_framework_parameters = {
362-
training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
363-
}
364-
365-
return env
354+
return MagicMock(module_dir=MODULE_DIR,
355+
user_entry_point=MODULE_NAME,
356+
hyperparameters={},
357+
log_level=LOG_LEVEL,
358+
hosts=HOST_LIST,
359+
current_host=CURRENT_HOST,
360+
to_env_vars=lambda: {},
361+
additional_framework_parameters={
362+
training.SAGEMAKER_PARAMETER_SERVER_ENABLED: True
363+
})
366364

367365

368366
@pytest.fixture
369367
def single_machine_training_env():
370368
env = MagicMock()
371369

372370
env.module_dir = MODULE_DIR
373-
env.module_name = MODULE_NAME
371+
env.user_entry_point = MODULE_NAME
374372
env.hyperparameters = {'model_dir': MODEL_DIR}
375373
env.log_level = LOG_LEVEL
376374

@@ -383,51 +381,95 @@ def test_is_host_master():
383381
assert training._is_host_master(HOST_LIST, 'somehost') is False
384382

385383

386-
@patch('sagemaker_containers.beta.framework.modules.run_module')
384+
@patch('sagemaker_containers.beta.framework.entry_point.run')
387385
def test_single_machine(run_module, single_machine_training_env):
388386
training.train(single_machine_training_env)
387+
<<<<<<< HEAD
389388
run_module.assert_called_with(MODULE_DIR, single_machine_training_env.to_cmd_args(),
390389
single_machine_training_env.to_env_vars(), MODULE_NAME)
391390
<<<<<<< HEAD
392391
>>>>>>> Scriptmode single machine training implementation (#78)
393392
=======
393+
=======
394+
run_module.assert_called_with(MODULE_DIR, MODULE_NAME,
395+
single_machine_training_env.to_cmd_args(),
396+
single_machine_training_env.to_env_vars())
397+
>>>>>>> Update sagemaker containers (#119)
394398

395399

396-
@patch('sagemaker_tensorflow_container.training._wait_until_master_is_down')
397-
@patch('sagemaker_tensorflow_container.training._run_worker')
398-
@patch('sagemaker_tensorflow_container.training._run_ps')
399-
def test_train_distributed_master(run_ps,
400-
run_worker,
401-
wait_until_master_is_down,
402-
distributed_training_env):
400+
@patch('sagemaker_containers.beta.framework.entry_point.run')
401+
@patch('time.sleep', MagicMock())
402+
def test_train_distributed_master(run, distributed_training_env):
403403
training.train(distributed_training_env)
404-
run_ps.assert_called_with(distributed_training_env)
405-
run_worker.assert_called_with(distributed_training_env, install_module=False)
406-
wait_until_master_is_down.assert_not_called
407404

405+
ps_tf_config = '{"cluster": {' \
406+
'"master": ["host1:2222"], ' \
407+
'"ps": ["host1:2223", "host2:2223"], ' \
408+
'"worker": ["host2:2222"]}, ' \
409+
'"environment": "cloud", ' \
410+
'"task": {"index": 0, "type": "ps"}}'
408411

409-
@patch('sagemaker_tensorflow_container.training._wait_until_master_is_down')
410-
@patch('sagemaker_tensorflow_container.training._run_worker')
411-
@patch('sagemaker_tensorflow_container.training._run_ps')
412-
def test_train_distributed_worker(run_ps,
413-
run_worker,
414-
wait_until_master_is_down,
412+
run.assert_any_call('s3://my/bucket', 'script_name',
413+
distributed_training_env.to_cmd_args(),
414+
{'TF_CONFIG': ps_tf_config})
415+
416+
master_tf_config = '{"cluster": {' \
417+
'"master": ["host1:2222"], ' \
418+
'"ps": ["host1:2223", "host2:2223"], ' \
419+
'"worker": ["host2:2222"]}, ' \
420+
'"environment": "cloud", ' \
421+
'"task": {"index": 0, "type": "master"}}'
422+
423+
run.assert_called_with('s3://my/bucket', 'script_name',
424+
distributed_training_env.to_cmd_args(),
425+
{
426+
'TF_CONFIG': master_tf_config})
427+
428+
429+
@patch('subprocess.check_call')
430+
@patch('time.sleep', MagicMock())
431+
@patch('sagemaker_containers.beta.framework.entry_point.run')
432+
def test_train_distributed_worker(run,
433+
check_call,
415434
distributed_training_env):
416435
distributed_training_env.current_host = HOST2
436+
check_call.side_effect = subprocess.CalledProcessError(returncode=1, cmd=[])
437+
417438
training.train(distributed_training_env)
418-
run_ps.assert_called_with(distributed_training_env)
419-
run_worker.assert_called_with(distributed_training_env, install_module=False)
420-
wait_until_master_is_down.assert_called_with(HOST1)
421439

440+
ps_tf_config = '{"cluster": {' \
441+
'"master": ["host1:2222"], ' \
442+
'"ps": ["host1:2223", "host2:2223"], ' \
443+
'"worker": ["host2:2222"]}, ' \
444+
'"environment": "cloud", ' \
445+
'"task": {"index": 1, "type": "ps"}}'
446+
447+
run.assert_any_call('s3://my/bucket', 'script_name',
448+
distributed_training_env.to_cmd_args(),
449+
{'TF_CONFIG': ps_tf_config})
450+
451+
master_tf_config = '{"cluster": {' \
452+
'"master": ["host1:2222"], ' \
453+
'"ps": ["host1:2223", "host2:2223"], ' \
454+
'"worker": ["host2:2222"]}, ' \
455+
'"environment": "cloud", ' \
456+
'"task": {"index": 0, "type": "worker"}}'
422457

423-
@patch('sagemaker_containers.beta.framework.modules.run_module')
424-
def test_train_distributed_no_ps(run_module, distributed_training_env):
458+
run.assert_called_with('s3://my/bucket', 'script_name',
459+
distributed_training_env.to_cmd_args(),
460+
{
461+
'TF_CONFIG': master_tf_config})
462+
463+
464+
@patch('sagemaker_containers.beta.framework.entry_point.run')
465+
def test_train_distributed_no_ps(run, distributed_training_env):
425466
distributed_training_env.additional_framework_parameters[
426467
training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
427468
distributed_training_env.current_host = HOST2
428469
training.train(distributed_training_env)
429-
run_module.assert_called_with(MODULE_DIR, distributed_training_env.to_cmd_args(),
430-
distributed_training_env.to_env_vars(), MODULE_NAME)
470+
471+
run.assert_called_with(MODULE_DIR, MODULE_NAME, distributed_training_env.to_cmd_args(),
472+
distributed_training_env.to_env_vars())
431473

432474

433475
@patch('sagemaker_tensorflow_container.training._build_tf_config')
@@ -441,61 +483,26 @@ def test_get_env_vars_with_tf_config(build_tf_config, distributed_training_env):
441483
hosts=HOST_LIST, current_host=CURRENT_HOST, ps_task=True)
442484

443485

444-
@patch('sagemaker_containers.beta.framework.modules.run_module')
486+
@patch('sagemaker_containers.beta.framework.entry_point.run')
445487
@patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
446-
def test_run_ps(env_vars_with_tf_config, run_module, distributed_training_env):
447-
env_vars_with_tf_config.return_value = {}
448-
distributed_training_env.to_cmd_args.return_value = CMD_ARGS
488+
def test_run_ps(env_vars_with_tf_config, run, distributed_training_env):
449489
training._run_ps(distributed_training_env)
450490
env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=True)
451-
run_module.assert_called_once_with(distributed_training_env.module_dir,
452-
CMD_ARGS,
453-
{},
454-
distributed_training_env.module_name,
455-
wait=False)
456-
457491

458-
@patch('sagemaker_containers.beta.framework.modules.write_env_vars')
459-
@patch('sagemaker_containers.beta.framework.modules.run')
460-
@patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
461-
def test_run_worker_no_install(get_env_vars_with_tf_config,
462-
run,
463-
write_env_vars,
464-
distributed_training_env):
465-
get_env_vars_with_tf_config.return_value = {}
466-
distributed_training_env.to_cmd_args.return_value = CMD_ARGS
467-
training._run_worker(distributed_training_env, install_module=False)
468-
get_env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=False)
469-
write_env_vars.assert_called_once_with({})
470-
run.assert_called_once_with(distributed_training_env.module_name,
471-
CMD_ARGS,
472-
{})
473-
474-
475-
@patch('sagemaker_containers.beta.framework.modules.run_module')
476-
@patch('sagemaker_tensorflow_container.training._env_vars_with_tf_config')
477-
def test_run_worker_install(get_env_vars_with_tf_config,
478-
run_module,
479-
distributed_training_env):
480-
get_env_vars_with_tf_config.return_value = {}
481-
distributed_training_env.to_cmd_args.return_value = CMD_ARGS
482-
training._run_worker(distributed_training_env, install_module=True)
483-
get_env_vars_with_tf_config.assert_called_once_with(distributed_training_env, ps_task=False)
484-
run_module.assert_called_once_with(distributed_training_env.module_dir,
485-
CMD_ARGS,
486-
{},
487-
distributed_training_env.module_name)
492+
run.assert_called_once_with(distributed_training_env.module_dir,
493+
distributed_training_env.user_entry_point,
494+
distributed_training_env.to_cmd_args(), env_vars_with_tf_config())
488495

489496

490497
def test_build_tf_config():
491-
assert training._build_tf_config(HOST_LIST, HOST1) ==\
492-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
498+
assert training._build_tf_config(HOST_LIST, HOST1) == \
499+
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': MASTER_TASK}
493500
assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == \
494-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
495-
assert training._build_tf_config(HOST_LIST, HOST2) ==\
496-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
501+
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_1}
502+
assert training._build_tf_config(HOST_LIST, HOST2) == \
503+
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': WORKER_TASK}
497504
assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == \
498-
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}
505+
{'cluster': CLUSTER_WITH_PS, 'environment': 'cloud', 'task': PS_TASK_2}
499506

500507

501508
def test_build_tf_config_error():

tox.ini

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,14 @@ passenv =
6060
# Can be used to specify which tests to run, e.g.: tox -- -s
6161
commands =
6262
coverage run --rcfile .coveragerc_{envname} --source sagemaker_tensorflow_container -m py.test {posargs}
63+
<<<<<<< HEAD
6364
{env:IGNORE_COVERAGE:} coverage report --include *sagemaker_tensorflow_container* --show-missing
6465
deps = sagemaker-containers
6566
extras = test
67+
=======
68+
{env:IGNORE_COVERAGE:} coverage report --fail-under=90 --include *sagemaker_tensorflow_container* --show-missing
69+
deps = .[test]
70+
>>>>>>> Update sagemaker containers (#119)
6671

6772
[testenv:flake8]
6873
basepython = python

0 commit comments

Comments
 (0)