Skip to content

Commit 005ce70

Browse files
icywang86ruiEliza Zhang
authored andcommitted
Create parameter server in different thread (aws#129)
* Create parameter server in different thread * Fixing some integ tests
1 parent bb3f360 commit 005ce70

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
import multiprocessing
1818
import os
1919
import subprocess
20+
import threading
2021
import time
2122

2223
import sagemaker_containers.beta.framework as framework
2324
import tensorflow as tf
2425

2526
from sagemaker_tensorflow_container import s3_utils
2627

27-
2828
logger = logging.getLogger(__name__)
2929

3030
SAGEMAKER_PARAMETER_SERVER_ENABLED = 'sagemaker_parameter_server_enabled'

test/integration/local/test_training.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import tarfile
1717

1818
import pytest
19-
from sagemaker.estimator import Framework
2019
from sagemaker.tensorflow import TensorFlow
2120

2221
from test.integration.utils import processor, py_version # noqa: F401

test/resources/mnist/distributed_mnist.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tensorflow.python.platform import tf_logging
1212
import logging as _logging
1313
import sys as _sys
14+
import json
1415

1516

1617
def cnn_model_fn(features, labels, mode):
@@ -122,9 +123,8 @@ def _parse_args():
122123
parser.add_argument('--epochs', type=int, default=1)
123124
# Data, model, and output directories
124125
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
125-
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
126+
parser.add_argument('--model_dir', type=str, default=os.environ['SM_MODEL_DIR'])
126127
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
127-
parser.add_argument('--checkpoint_path', type=str, default=os.environ['SM_MODEL_DIR'])
128128

129129
return parser.parse_known_args()
130130

@@ -140,8 +140,12 @@ def _parse_args():
140140
eval_data, eval_labels = _load_testing_data(args.train)
141141

142142
# Create the Estimator
143+
if json.loads(os.environ['SM_TRAINING_ENV'])['additional_framework_parameters'].get('sagemaker_parameter_server_enabled'):
144+
model_dir = args.model_dir
145+
else:
146+
model_dir = os.environ['SM_MODEL_DIR']
143147
mnist_classifier = tf.estimator.Estimator(
144-
model_fn=cnn_model_fn, model_dir=args.checkpoint_path)
148+
model_fn=cnn_model_fn, model_dir=model_dir)
145149

146150
# Set up logging for predictions
147151
# Log the values in the "Softmax" tensor with label "probabilities"

test/unit/test_training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def test_single_machine(run_module, single_machine_training_env):
9191

9292
@pytest.mark.skipif(sys.version_info.major != 3,
9393
reason="Skip this for python 2 because of dict key order mismatch")
94+
@patch('tensorflow.train.ClusterSpec')
95+
@patch('tensorflow.train.Server')
9496
@patch('sagemaker_containers.beta.framework.entry_point.run')
9597
def test_train_horovod(run_module, single_machine_training_env):
9698
single_machine_training_env.additional_framework_parameters['sagemaker_mpi_enabled'] = True

0 commit comments

Comments
 (0)