Skip to content

Commit b864943

Browse files
authored
Use multiprocessing.Process to launch parameter server (#203)
* Use multiprocessing.Process to launch parameter server * Force ps process to run on cpu
1 parent dfee185 commit b864943

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/tf_container/train_entry_point.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import subprocess
1818
import time
19-
from threading import Thread
19+
from multiprocessing import Process
2020

2121
import tensorflow as tf
2222

@@ -64,11 +64,15 @@ def _run_ps_server(current_host, hosts, tf_config):
6464
def start_ps_server(current_host, hosts, tf_config):
6565
cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])
6666
task_index = hosts.index(current_host)
67-
server = tf.train.Server(cluster_spec, job_name='ps', task_index=task_index)
67+
# Force parameter server to run on cpu. Running multiple TensorFlow processes on the same
68+
# GPU is not safe:
69+
# https://stackoverflow.com/questions/46145100/is-it-unsafe-to-run-multiple-tensorflow-processes-on-the-same-gpu
70+
no_gpu_config = tf.ConfigProto(device_count={'GPU': 0})
71+
server = tf.train.Server(cluster_spec, job_name='ps', task_index=task_index, config=no_gpu_config)
6872
server.join()
6973

70-
t = Thread(target=start_ps_server, args=(current_host, hosts, tf_config))
71-
t.start()
74+
p = Process(target=start_ps_server, args=(current_host, hosts, tf_config))
75+
p.start()
7276

7377

7478
def _get_default_training_params(env):

0 commit comments

Comments
 (0)