
Commit f5beede

jeffdaily authored and jithunnair-amd committed
parallel_apply should forward current streams to worker threads (pytorch#78824)
pytorch#71033 moved test_data_parallel_module et al. under `instantiate_device_type_tests`, which had the side effect of running those tests on a non-default stream. `parallel_apply` creates one new thread per device but does not forward the thread-local current streams from the parent thread, so the new per-device threads default to the null stream. The null stream does not synchronize with non-default, non-blocking streams, which produces errors when these tests assert that tensors are equal.

CC @janeyx99

Pull Request resolved: pytorch#78824
Approved by: https://github.com/pruthvistony, https://github.com/janeyx99
1 parent 6b07949 commit f5beede
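
As a minimal sketch of the failure mode described above (illustrative only, not code from this PR; the tensor names and sizes are made up), a spawned thread does not inherit the parent's thread-local current stream and starts on the null (default) stream:

```python
import threading

import torch

# Sketch of the failure mode: the parent thread queues a matmul on a side
# stream, but a freshly spawned thread starts on the null (default) stream
# because the thread-local current stream is not inherited.
device = torch.device("cuda:0")
side_stream = torch.cuda.Stream(device=device)

with torch.cuda.stream(side_stream):
    x = torch.randn(1024, 1024, device=device)
    y = x @ x  # enqueued on side_stream, possibly still running

def worker():
    # New threads default to the null stream, which does not synchronize
    # with non-default, non-blocking streams.
    assert torch.cuda.current_stream(device) == torch.cuda.default_stream(device)
    # Reading `y` here without syncing on side_stream can observe stale data.

t = threading.Thread(target=worker)
t.start()
t.join()
```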

File tree

1 file changed: 9 additions, 6 deletions


torch/nn/parallel/parallel_apply.py (9 additions, 6 deletions)
```diff
@@ -45,16 +45,19 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
     else:
         devices = [None] * len(modules)
     devices = [_get_device_index(x, True) for x in devices]
+    streams = [torch.cuda.current_stream(x) for x in devices]
     lock = threading.Lock()
     results = {}
     grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
 
-    def _worker(i, module, input, kwargs, device=None):
+    def _worker(i, module, input, kwargs, device=None, stream=None):
         torch.set_grad_enabled(grad_enabled)
         if device is None:
             device = get_a_var(input).get_device()
+        if stream is None:
+            stream = torch.cuda.current_stream(device)
         try:
-            with torch.cuda.device(device), autocast(enabled=autocast_enabled):
+            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                 # this also avoids accidental slicing of `input` if it is a Tensor
                 if not isinstance(input, (list, tuple)):
                     input = (input,)
@@ -68,16 +71,16 @@ def _worker(i, module, input, kwargs, device=None):
 
     if len(modules) > 1:
         threads = [threading.Thread(target=_worker,
-                                    args=(i, module, input, kwargs, device))
-                   for i, (module, input, kwargs, device) in
-                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
+                                    args=(i, module, input, kwargs, device, stream))
+                   for i, (module, input, kwargs, device, stream) in
+                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]
 
         for thread in threads:
             thread.start()
         for thread in threads:
             thread.join()
     else:
-        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
 
     outputs = []
     for i in range(len(inputs)):
```
