Skip to content

Commit 962f15b

Browse files
authored
Create parameter server in different thread (#129)
* Create parameter server in different thread
* Fixing some integ tests
1 parent 49a0547 commit 962f15b

File tree

7 files changed

+114
-159
lines changed

7 files changed

+114
-159
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,6 @@ def read(fname):
5353
'pandas', 'Pillow', 'h5py'],
5454
extras_require={
5555
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',
56-
'sagemaker', 'tensorflow', 'docker-compose']
56+
'sagemaker>=1.15.2', 'tensorflow', 'docker-compose']
5757
},
5858
)

src/sagemaker_tensorflow_container/training.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717
import logging
1818
import os
1919
import subprocess
20+
import threading
2021
import time
2122

2223
import sagemaker_containers.beta.framework as framework
24+
import tensorflow as tf
2325

24-
import sagemaker_tensorflow_container.s3_utils as s3_utils
25-
26+
from sagemaker_tensorflow_container import s3_utils
2627

2728
logger = logging.getLogger(__name__)
2829

29-
3030
SAGEMAKER_PARAMETER_SERVER_ENABLED = 'sagemaker_parameter_server_enabled'
3131

3232

@@ -88,30 +88,21 @@ def host_addresses(hosts, port=2222):
8888
return tf_config
8989

9090

91-
def _env_vars_with_tf_config(env, ps_task):
91+
def _run_ps(env, cluster):
92+
logger.info('Running distributed training job with parameter servers')
93+
94+
cluster_spec = tf.train.ClusterSpec(cluster)
95+
task_index = env.hosts.index(env.current_host)
96+
97+
server = tf.train.Server(cluster_spec, job_name='ps', task_index=task_index)
98+
99+
threading.Thread(target=lambda: server.join()).start()
100+
101+
102+
def _run_worker(env, tf_config):
92103
env_vars = env.to_env_vars()
93-
env_vars['TF_CONFIG'] = json.dumps(_build_tf_config(
94-
hosts=env.hosts,
95-
current_host=env.current_host,
96-
ps_task=ps_task))
97-
return env_vars
98-
99-
100-
def _run_ps(env):
101-
env_vars = _env_vars_with_tf_config(env, ps_task=True)
102-
# Parameter server processes should always run on CPU. Sets CUDA_VISIBLE_DEVICES to '-1' forces
103-
# TensorFlow to use CPU.
104-
env_vars['CUDA_VISIBLE_DEVICES'] = json.dumps(-1)
105-
framework.entry_point.run(env.module_dir, env.user_entry_point,
106-
env.to_cmd_args(), env_vars, wait=False)
107-
108-
109-
def _run_worker(env):
110-
# when _run_ps is called CUDA_VISIBLE_DEVICES is set with os.environ.
111-
# We need to unset it so the worker process can use the GPUs.
112-
if os.environ.get('CUDA_VISIBLE_DEVICES'):
113-
del os.environ['CUDA_VISIBLE_DEVICES']
114-
env_vars = _env_vars_with_tf_config(env, ps_task=False)
104+
env_vars['TF_CONFIG'] = json.dumps(tf_config)
105+
115106
framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
116107

117108

@@ -137,11 +128,13 @@ def train(env):
137128
SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
138129
if len(env.hosts) > 1 and parameter_server_enabled:
139130

131+
tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)
132+
140133
logger.info('Running distributed training job with parameter servers')
141134
logger.info('Launching parameter server process')
142-
_run_ps(env)
135+
_run_ps(env, tf_config['cluster'])
143136
logger.info('Launching worker process')
144-
_run_worker(env)
137+
_run_worker(env, tf_config)
145138

146139
if not _is_host_master(env.hosts, env.current_host):
147140
_wait_until_master_is_down(env.hosts[0])

test/integration/local/test_keras.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_keras_training(sagemaker_local_session, docker_image, tmpdir):
3333
role='SageMakerRole',
3434
train_instance_count=1,
3535
train_instance_type='local',
36+
image_name=docker_image,
3637
sagemaker_session=sagemaker_local_session,
3738
model_dir='/opt/ml/model',
3839
output_path=output_path,
@@ -41,7 +42,8 @@ def test_keras_training(sagemaker_local_session, docker_image, tmpdir):
4142

4243
estimator.fit()
4344

44-
model = serving.Model(model_data=output_path, role='SageMakerRole',
45+
model = serving.Model(model_data=output_path,
46+
role='SageMakerRole',
4547
framework_version='1.11.0',
4648
sagemaker_session=sagemaker_local_session)
4749

test/integration/local/test_training.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
import tarfile
1717

1818
import pytest
19-
from sagemaker.estimator import Framework
2019
from sagemaker.tensorflow import TensorFlow
2120

2221
from test.integration.docker_utils import Container
2322

24-
2523
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
2624
TF_CHECKPOINT_FILES = ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta']
2725

@@ -94,34 +92,27 @@ def test_distributed_training_cpu_ps(sagemaker_local_session, docker_image, tmpd
9492
_assert_files_exist_in_tar(output_path, TF_CHECKPOINT_FILES)
9593

9694

97-
class ScriptModeTensorFlow(Framework):
98-
"""This class is temporary until the final version of Script Mode is released.
99-
"""
100-
101-
__framework_name__ = "tensorflow-scriptmode-beta"
102-
103-
create_model = TensorFlow.create_model
104-
105-
def __init__(self, py_version='py', **kwargs):
106-
self.requirements_file = None
107-
self.py_version = py_version
108-
self.framework_version = 'some version'
109-
super(ScriptModeTensorFlow, self).__init__(**kwargs)
110-
111-
112-
def run_tf_training(script, instance_type, instance_count,
95+
def run_tf_training(script,
96+
instance_type,
97+
instance_count,
11398
sagemaker_local_session,
11499
docker_image, training_data_path, output_path=None,
115-
hyperparameters={}):
116-
estimator = ScriptModeTensorFlow(entry_point=script,
117-
role='SageMakerRole',
118-
train_instance_count=instance_count,
119-
train_instance_type=instance_type,
120-
sagemaker_session=sagemaker_local_session,
121-
image_name=docker_image,
122-
output_path=output_path,
123-
hyperparameters=hyperparameters,
124-
base_job_name='test-tf')
100+
hyperparameters=None):
101+
102+
hyperparameters = hyperparameters or {}
103+
104+
estimator = TensorFlow(entry_point=script,
105+
role='SageMakerRole',
106+
train_instance_count=instance_count,
107+
train_instance_type=instance_type,
108+
sagemaker_session=sagemaker_local_session,
109+
image_name=docker_image,
110+
model_dir='/opt/ml/model',
111+
output_path=output_path,
112+
hyperparameters=hyperparameters,
113+
base_job_name='test-tf',
114+
framework_version='1.11.0',
115+
py_version='py3')
125116

126117
estimator.fit(training_data_path)
127118

test/integration/sagemaker/test_mnist.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ def test_mnist(sagemaker_session, ecr_image, instance_type):
2626
script = os.path.join(resource_path, 'mnist', 'mnist.py')
2727
estimator = TensorFlow(entry_point=script,
2828
role='SageMakerRole',
29-
training_steps=1,
30-
evaluation_steps=1,
31-
train_instance_count=1,
3229
train_instance_type=instance_type,
30+
train_instance_count=1,
3331
sagemaker_session=sagemaker_session,
3432
image_name=ecr_image,
33+
framework_version='1.11.0',
34+
py_version='py3',
3535
base_job_name='test-sagemaker-mnist')
3636
inputs = estimator.sagemaker_session.upload_data(
3737
path=os.path.join(resource_path, 'mnist', 'data'),
@@ -46,44 +46,41 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type):
4646
script = os.path.join(resource_path, 'mnist', 'distributed_mnist.py')
4747
estimator = TensorFlow(entry_point=script,
4848
role='SageMakerRole',
49-
training_steps=1,
50-
evaluation_steps=1,
5149
train_instance_count=2,
5250
train_instance_type=instance_type,
5351
sagemaker_session=sagemaker_session,
5452
image_name=ecr_image,
53+
framework_version='1.11.0',
54+
py_version='py3',
5555
base_job_name='test-tf-sm-distributed-mnist')
5656
inputs = estimator.sagemaker_session.upload_data(
5757
path=os.path.join(resource_path, 'mnist', 'data-distributed'),
5858
key_prefix='scriptmode/mnist-distributed')
5959
estimator.fit(inputs)
60-
_assert_s3_file_exists(os.path.join(estimator.checkpoint_path, 'graph.pbtxt'))
61-
_assert_s3_file_exists(os.path.join(estimator.checkpoint_path, 'model.ckpt-0.index'))
62-
_assert_s3_file_exists(os.path.join(estimator.checkpoint_path, 'model.ckpt-0.meta'))
60+
model_s3_url = estimator.create_model().model_data
61+
_assert_s3_file_exists(model_s3_url)
6362

6463

6564
def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type):
6665
resource_path = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
6766
script = os.path.join(resource_path, 'mnist', 'distributed_mnist.py')
6867
estimator = TensorFlow(entry_point=script,
6968
role='SageMakerRole',
70-
# training_steps and evaluation_steps are legacy parameters from
71-
# framework mode. These number are not used in the training job.
72-
training_steps=1,
73-
evaluation_steps=1,
7469
hyperparameters={SAGEMAKER_PARAMETER_SERVER_ENABLED: True},
7570
train_instance_count=2,
7671
train_instance_type=instance_type,
7772
sagemaker_session=sagemaker_session,
7873
image_name=ecr_image,
74+
framework_version='1.11.0',
75+
py_version='py3',
7976
base_job_name='test-tf-sm-distributed-mnist')
8077
inputs = estimator.sagemaker_session.upload_data(
8178
path=os.path.join(resource_path, 'mnist', 'data-distributed'),
8279
key_prefix='scriptmode/mnist-distributed')
8380
estimator.fit(inputs)
84-
_assert_s3_file_exists(os.path.join(estimator.checkpoint_path, 'graph.pbtxt'))
85-
_assert_s3_file_exists(os.path.join(estimator.checkpoint_path, 'model.ckpt-0.index'))
86-
_assert_s3_file_exists(os.path.join(estimator.checkpoint_path, 'model.ckpt-0.meta'))
81+
_assert_s3_file_exists(os.path.join(estimator.model_dir, 'graph.pbtxt'))
82+
_assert_s3_file_exists(os.path.join(estimator.model_dir, 'model.ckpt-0.index'))
83+
_assert_s3_file_exists(os.path.join(estimator.model_dir, 'model.ckpt-0.meta'))
8784

8885

8986
def _assert_s3_file_exists(s3_url):

test/resources/mnist/distributed_mnist.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tensorflow.python.platform import tf_logging
1212
import logging as _logging
1313
import sys as _sys
14+
import json
1415

1516

1617
def cnn_model_fn(features, labels, mode):
@@ -122,9 +123,8 @@ def _parse_args():
122123
parser.add_argument('--epochs', type=int, default=1)
123124
# Data, model, and output directories
124125
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
125-
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
126+
parser.add_argument('--model_dir', type=str, default=os.environ['SM_MODEL_DIR'])
126127
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
127-
parser.add_argument('--checkpoint_path', type=str, default=os.environ['SM_MODEL_DIR'])
128128

129129
return parser.parse_known_args()
130130

@@ -140,8 +140,12 @@ def _parse_args():
140140
eval_data, eval_labels = _load_testing_data(args.train)
141141

142142
# Create the Estimator
143+
if json.loads(os.environ['SM_TRAINING_ENV'])['additional_framework_parameters'].get('sagemaker_parameter_server_enabled'):
144+
model_dir = args.model_dir
145+
else:
146+
model_dir = os.environ['SM_MODEL_DIR']
143147
mnist_classifier = tf.estimator.Estimator(
144-
model_fn=cnn_model_fn, model_dir=args.checkpoint_path)
148+
model_fn=cnn_model_fn, model_dir=model_dir)
145149

146150
# Set up logging for predictions
147151
# Log the values in the "Softmax" tensor with label "probabilities"

0 commit comments

Comments
 (0)