
Commit a9e4359

Force parameter server to run on CPU (#143)
1 parent 441adb0 commit a9e4359

2 files changed: +15 −4 lines


src/sagemaker_tensorflow_container/training.py

Lines changed: 8 additions & 2 deletions
@@ -93,8 +93,14 @@ def _run_ps(env, cluster):
 
     cluster_spec = tf.train.ClusterSpec(cluster)
     task_index = env.hosts.index(env.current_host)
-
-    server = tf.train.Server(cluster_spec, job_name='ps', task_index=task_index)
+    # Force parameter server to run on cpu. Running multiple TensorFlow processes on the same
+    # GPU is not safe:
+    # https://stackoverflow.com/questions/46145100/is-it-unsafe-to-run-multiple-tensorflow-processes-on-the-same-gpu
+    no_gpu_config = tf.ConfigProto(device_count={'GPU': 0})
+
+    server = tf.train.Server(
+        cluster_spec, job_name='ps', task_index=task_index, config=no_gpu_config
+    )
 
     threading.Thread(target=lambda: server.join()).start()
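
For context, here is a minimal standalone sketch of the pattern this change applies, using the same TensorFlow 1.x APIs that appear in the diff (tf.train.ClusterSpec, tf.ConfigProto, tf.train.Server). The cluster layout, host names, and task index below are illustrative only and are not taken from the repository.

# Minimal sketch (TensorFlow 1.x API) of a parameter-server process that joins the
# cluster with all GPUs hidden, so it never competes with the training worker for
# GPU memory on the same host.
# The cluster layout, host names, and task index are illustrative, not from the repo.
import threading

import tensorflow as tf

cluster = {
    'master': ['host1:2222'],
    'ps': ['host1:2223', 'host2:2223'],
}
cluster_spec = tf.train.ClusterSpec(cluster)

# device_count={'GPU': 0} makes this process expose zero GPU devices; the parameter
# server only stores and serves variables, which is CPU work.
no_gpu_config = tf.ConfigProto(device_count={'GPU': 0})

server = tf.train.Server(
    cluster_spec, job_name='ps', task_index=0, config=no_gpu_config
)

# Block in a background thread, as training.py does, so the caller can keep going.
threading.Thread(target=server.join).start()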

test/unit/test_training.py

Lines changed: 7 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 from mock import MagicMock, patch
 import pytest
+import tensorflow as tf
 
 from sagemaker_tensorflow_container import training
 
@@ -98,7 +99,9 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
         'master': ['host1:2222'],
         'ps': ['host1:2223', 'host2:2223']})
 
-    tf_server.assert_called_with(cluster_spec(), job_name='ps', task_index=0)
+    tf_server.assert_called_with(
+        cluster_spec(), job_name='ps', task_index=0, config=tf.ConfigProto(device_count={'GPU': 0})
+    )
     tf_server().join.assert_called_with()
 
     tf_config = '{"cluster": {' \
@@ -128,7 +131,9 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
         'master': ['host1:2222'],
         'ps': ['host1:2223', 'host2:2223']})
 
-    tf_server.assert_called_with(cluster_spec(), job_name='ps', task_index=1)
+    tf_server.assert_called_with(
+        cluster_spec(), job_name='ps', task_index=1, config=tf.ConfigProto(device_count={'GPU': 0})
+    )
     tf_server().join.assert_called_with()
 
     tf_config = '{"cluster": {' \
