Skip to content

Commit f19e6c0

Browse files
use condition_variable and wait_until in nccl dump on timeout (pytorch#120544) (#1365)
Fixes test_c10d_nccl.py -k test_timeout_dumps_timing_enabled_True. Pull Request resolved: pytorch#120544. Approved by: https://github.com/atalman. Co-authored-by: Jeff Daily <[email protected]>
1 parent 6ab5d2a commit f19e6c0

File tree

3 files changed

+35
-12
lines changed

3 files changed

+35
-12
lines changed

c10/util/signal_handler.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
#include <sys/syscall.h>
1111
#include <unistd.h>
1212

13+
#include <atomic>
14+
#include <chrono>
15+
#include <condition_variable>
1316
#include <cstdint>
1417
#include <cstdio>
1518
#include <cstdlib>
1619
#include <iostream>
20+
#include <mutex>
1721

1822
#ifdef C10_ANDROID
1923
#ifndef SYS_gettid
@@ -109,8 +113,9 @@ FatalSignalHandler::FatalSignalHandler()
109113
: fatalSignalHandlersInstalled(false),
110114
fatalSignalReceived(false),
111115
fatalSignalName("<UNKNOWN>"),
112-
writingCond(PTHREAD_COND_INITIALIZER),
113-
writingMutex(PTHREAD_MUTEX_INITIALIZER) {}
116+
writingCond(),
117+
writingMutex(),
118+
signalReceived(false) {}
114119

115120
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
116121
FatalSignalHandler::signal_handler FatalSignalHandler::kSignalHandlers[] = {
@@ -157,8 +162,10 @@ void FatalSignalHandler::callPreviousSignalHandler(
157162

158163
// needsLock signals whether we need to lock our writing mutex.
159164
void FatalSignalHandler::stacktraceSignalHandler(bool needsLock) {
165+
std::unique_lock<std::mutex> ul(writingMutex, std::defer_lock);
160166
if (needsLock) {
161-
pthread_mutex_lock(&writingMutex);
167+
ul.lock();
168+
signalReceived = true;
162169
}
163170
pid_t tid = static_cast<pid_t>(syscall(SYS_gettid));
164171
std::string backtrace = fmt::format(
@@ -170,8 +177,8 @@ void FatalSignalHandler::stacktraceSignalHandler(bool needsLock) {
170177
c10::get_backtrace());
171178
std::cerr << backtrace << std::endl;
172179
if (needsLock) {
173-
pthread_mutex_unlock(&writingMutex);
174-
pthread_cond_signal(&writingCond);
180+
ul.unlock();
181+
writingCond.notify_all();
175182
}
176183
}
177184

@@ -204,23 +211,32 @@ void FatalSignalHandler::fatalSignalHandler(int signum) {
204211
pid_t pid = getpid();
205212
pid_t currentTid = static_cast<pid_t>(syscall(SYS_gettid));
206213
struct dirent* entry = nullptr;
207-
pthread_mutex_lock(&writingMutex);
214+
std::unique_lock<std::mutex> ul(writingMutex);
208215
while ((entry = readdir(procDir)) != nullptr) {
209216
if (entry->d_name[0] == '.') {
210217
continue;
211218
}
212219
pid_t tid = atoi(entry->d_name);
213220
// If we've found the current thread then we'll jump into the SIGUSR2
214-
// handler before calling pthread_cond_wait thus deadlocking, so branch
215-
// our directly to the backtrace handler instead of signaling it.
221+
// handler directly instead of signaling ourselves, which would deadlock.
216222
if (tid != currentTid) {
223+
signalReceived = false;
217224
syscall(SYS_tgkill, pid, tid, SIGUSR2);
218-
pthread_cond_wait(&writingCond, &writingMutex);
225+
auto now = std::chrono::system_clock::now();
226+
using namespace std::chrono_literals;
227+
// we use wait_until instead of wait because on ROCm there was
228+
// a single thread that wouldn't receive the SIGUSR2
229+
if (std::cv_status::timeout == writingCond.wait_until(ul, now + 2s)) {
230+
if (!signalReceived) {
231+
std::cerr << "signal lost waiting for stacktrace " << pid << ":"
232+
<< tid << std::endl;
233+
break;
234+
}
235+
}
219236
} else {
220237
stacktraceSignalHandler(false);
221238
}
222239
}
223-
pthread_mutex_unlock(&writingMutex);
224240
} else {
225241
perror("Failed to open /proc/self/task");
226242
}

c10/util/signal_handler.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <atomic>
4+
#include <condition_variable>
45
#include <csignal>
56
#include <cstdint>
67
#include <mutex>
@@ -89,8 +90,10 @@ class C10_API FatalSignalHandler {
8990
// This wait condition is used to wait for other threads to finish writing
9091
// their stack trace when in fatal sig handler (we can't use pthread_join
9192
// because there's no way to convert from a tid to a pthread_t).
92-
pthread_cond_t writingCond;
93-
pthread_mutex_t writingMutex;
93+
std::condition_variable writingCond;
94+
std::mutex writingMutex;
95+
// used to indicate if the other thread responded to the signal
96+
bool signalReceived;
9497

9598
struct signal_handler {
9699
const char* name;

test/distributed/test_c10d_nccl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3955,6 +3955,10 @@ def _check_return_codes(self, elapsed_time):
39553955
@requires_nccl()
39563956
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
39573957
def test_timeout_dumps_on_stuck_ranks(self):
3958+
# need rank0 to crash quicker after detecting timeout
3959+
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1'
3960+
# restore this env var to its prior default in case another test changed it
3961+
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000'
39583962

39593963
if self.rank == self.MAIN_PROCESS_RANK:
39603964
# wait for both rank0 and 1 to crash before looking for both ranks' output

0 commit comments

Comments
 (0)