Skip to content

Create parameter server in different thread #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ def read(fname):
'pandas', 'Pillow', 'h5py'],
extras_require={
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',
'sagemaker', 'tensorflow', 'docker-compose']
'sagemaker>=1.15.2', 'tensorflow', 'docker-compose']
},
)
49 changes: 21 additions & 28 deletions src/sagemaker_tensorflow_container/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
import logging
import os
import subprocess
import threading
import time

import sagemaker_containers.beta.framework as framework
import tensorflow as tf

import sagemaker_tensorflow_container.s3_utils as s3_utils

from sagemaker_tensorflow_container import s3_utils

logger = logging.getLogger(__name__)


SAGEMAKER_PARAMETER_SERVER_ENABLED = 'sagemaker_parameter_server_enabled'


Expand Down Expand Up @@ -88,30 +88,21 @@ def host_addresses(hosts, port=2222):
return tf_config


def _env_vars_with_tf_config(env, ps_task):
def _run_ps(env, cluster):
logger.info('Running distributed training job with parameter servers')

cluster_spec = tf.train.ClusterSpec(cluster)
task_index = env.hosts.index(env.current_host)

server = tf.train.Server(cluster_spec, job_name='ps', task_index=task_index)

threading.Thread(target=lambda: server.join()).start()


def _run_worker(env, tf_config):
env_vars = env.to_env_vars()
env_vars['TF_CONFIG'] = json.dumps(_build_tf_config(
hosts=env.hosts,
current_host=env.current_host,
ps_task=ps_task))
return env_vars


def _run_ps(env):
env_vars = _env_vars_with_tf_config(env, ps_task=True)
# Parameter server processes should always run on CPU. Sets CUDA_VISIBLE_DEVICES to '-1' forces
# TensorFlow to use CPU.
env_vars['CUDA_VISIBLE_DEVICES'] = json.dumps(-1)
framework.entry_point.run(env.module_dir, env.user_entry_point,
env.to_cmd_args(), env_vars, wait=False)


def _run_worker(env):
# when _run_ps is called CUDA_VISIBLE_DEVICES is set with os.environ.
# We need to unset it so the worker process can use the GPUs.
if os.environ.get('CUDA_VISIBLE_DEVICES'):
del os.environ['CUDA_VISIBLE_DEVICES']
env_vars = _env_vars_with_tf_config(env, ps_task=False)
env_vars['TF_CONFIG'] = json.dumps(tf_config)

framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)


Expand All @@ -137,11 +128,13 @@ def train(env):
SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
if len(env.hosts) > 1 and parameter_server_enabled:

tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)

logger.info('Running distributed training job with parameter servers')
logger.info('Launching parameter server process')
_run_ps(env)
_run_ps(env, tf_config['cluster'])
logger.info('Launching worker process')
_run_worker(env)
_run_worker(env, tf_config)

if not _is_host_master(env.hosts, env.current_host):
_wait_until_master_is_down(env.hosts[0])
Expand Down
4 changes: 3 additions & 1 deletion test/integration/local/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_keras_training(sagemaker_local_session, docker_image, tmpdir):
role='SageMakerRole',
train_instance_count=1,
train_instance_type='local',
image_name=docker_image,
sagemaker_session=sagemaker_local_session,
model_dir='/opt/ml/model',
output_path=output_path,
Expand All @@ -41,7 +42,8 @@ def test_keras_training(sagemaker_local_session, docker_image, tmpdir):

estimator.fit()

model = serving.Model(model_data=output_path, role='SageMakerRole',
model = serving.Model(model_data=output_path,
role='SageMakerRole',
framework_version='1.11.0',
sagemaker_session=sagemaker_local_session)

Expand Down
47 changes: 19 additions & 28 deletions test/integration/local/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
import tarfile

import pytest
from sagemaker.estimator import Framework
from sagemaker.tensorflow import TensorFlow

from test.integration.docker_utils import Container


RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
TF_CHECKPOINT_FILES = ['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta']

Expand Down Expand Up @@ -94,34 +92,27 @@ def test_distributed_training_cpu_ps(sagemaker_local_session, docker_image, tmpd
_assert_files_exist_in_tar(output_path, TF_CHECKPOINT_FILES)


class ScriptModeTensorFlow(Framework):
"""This class is temporary until the final version of Script Mode is released.
"""

__framework_name__ = "tensorflow-scriptmode-beta"

create_model = TensorFlow.create_model

def __init__(self, py_version='py', **kwargs):
self.requirements_file = None
self.py_version = py_version
self.framework_version = 'some version'
super(ScriptModeTensorFlow, self).__init__(**kwargs)


def run_tf_training(script, instance_type, instance_count,
def run_tf_training(script,
instance_type,
instance_count,
sagemaker_local_session,
docker_image, training_data_path, output_path=None,
hyperparameters={}):
estimator = ScriptModeTensorFlow(entry_point=script,
role='SageMakerRole',
train_instance_count=instance_count,
train_instance_type=instance_type,
sagemaker_session=sagemaker_local_session,
image_name=docker_image,
output_path=output_path,
hyperparameters=hyperparameters,
base_job_name='test-tf')
hyperparameters=None):

hyperparameters = hyperparameters or {}

estimator = TensorFlow(entry_point=script,
role='SageMakerRole',
train_instance_count=instance_count,
train_instance_type=instance_type,
sagemaker_session=sagemaker_local_session,
image_name=docker_image,
model_dir='/opt/ml/model',
output_path=output_path,
hyperparameters=hyperparameters,
base_job_name='test-tf',
framework_version='1.11.0',
py_version='py3')

estimator.fit(training_data_path)

Expand Down
16 changes: 7 additions & 9 deletions test/integration/sagemaker/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def test_mnist(sagemaker_session, ecr_image, instance_type):
script = os.path.join(resource_path, 'mnist', 'mnist.py')
estimator = TensorFlow(entry_point=script,
role='SageMakerRole',
training_steps=1,
evaluation_steps=1,
train_instance_count=1,
train_instance_type=instance_type,
train_instance_count=1,
sagemaker_session=sagemaker_session,
image_name=ecr_image,
framework_version='1.11.0',
py_version='py3',
base_job_name='test-sagemaker-mnist')
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(resource_path, 'mnist', 'data'),
Expand All @@ -46,12 +46,12 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type):
script = os.path.join(resource_path, 'mnist', 'distributed_mnist.py')
estimator = TensorFlow(entry_point=script,
role='SageMakerRole',
training_steps=1,
evaluation_steps=1,
train_instance_count=2,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
image_name=ecr_image,
framework_version='1.11.0',
py_version='py3',
base_job_name='test-tf-sm-distributed-mnist')
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(resource_path, 'mnist', 'data-distributed'),
Expand All @@ -67,15 +67,13 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type):
script = os.path.join(resource_path, 'mnist', 'distributed_mnist.py')
estimator = TensorFlow(entry_point=script,
role='SageMakerRole',
# training_steps and evaluation_steps are legacy parameters from
# framework mode. These number are not used in the training job.
training_steps=1,
evaluation_steps=1,
hyperparameters={SAGEMAKER_PARAMETER_SERVER_ENABLED: True},
train_instance_count=2,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
image_name=ecr_image,
framework_version='1.11.0',
py_version='py3',
base_job_name='test-tf-sm-distributed-mnist')
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(resource_path, 'mnist', 'data-distributed'),
Expand Down
3 changes: 1 addition & 2 deletions test/resources/mnist/distributed_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def _parse_args():
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
parser.add_argument('--checkpoint_path', type=str, default=os.environ['SM_MODEL_DIR'])

return parser.parse_known_args()

Expand All @@ -141,7 +140,7 @@ def _parse_args():

# Create the Estimator
mnist_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir=args.checkpoint_path)
model_fn=cnn_model_fn, model_dir=args.model_dir)

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
Expand Down
Loading