Skip to content

Commit 0aa5ba4

Browse files
authored
[mlir] Fix DistinctAttributeUniquer deleting attribute storage when crash reproduction is enabled (#128566)
Currently, `DistinctAttr` uses an allocator wrapped in a `ThreadLocalCache` to manage attribute storage allocations. This ensures all allocations are freed when the allocator is destroyed. However, this setup can cause use-after-free errors when `mlir::PassManager` runs its passes on a separate thread as a result of crash reproduction being enabled. Distinct attribute storages are created in the child thread's local storage and freed once the thread joins. Attempting to access these attributes after this can result in segmentation faults, such as during printing or alias analysis. Example: This invocation of `mlir-opt` demonstrates the segfault issue due to distinct attributes being created in a child thread and their storage being freed once the thread joins: ``` mlir-opt --mlir-pass-pipeline-crash-reproducer=. --test-distinct-attrs mlir/test/IR/test-builtin-distinct-attrs.mlir ``` This pull request changes the distinct attribute allocator to use different allocators depending on whether or not threading is enabled and whether or not the pass manager is running its passes in a separate thread. If multithreading is disabled, a non thread-local allocator is used. If threading remains enabled and the pass manager invokes its pass pipelines in a child thread, then a non-thread local but synchronised allocator is used. This ensures that the lifetime of allocated storage persists beyond the lifetime of the child thread. I have added two tests for the `-test-distinct-attrs` pass and the `-enable-debug-info-on-llvm-scope` passes that run them with crash reproduction enabled.
1 parent c26ec7e commit 0aa5ba4

File tree

6 files changed

+106
-4
lines changed

6 files changed

+106
-4
lines changed

mlir/include/mlir/IR/MLIRContext.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ class MLIRContext {
153153
disableMultithreading(!enable);
154154
}
155155

156+
/// Set the flag specifying if thread-local storage should be used by storage
157+
/// allocators in this context. Note that disabling mutlithreading implies
158+
/// thread-local storage is also disabled.
159+
void disableThreadLocalStorage(bool disable = true);
160+
void enableThreadLocalStorage(bool enable = true) {
161+
disableThreadLocalStorage(!enable);
162+
}
163+
156164
/// Set a new thread pool to be used in this context. This method requires
157165
/// that multithreading is disabled for this context prior to the call. This
158166
/// allows to share a thread pool across multiple contexts, as well as

mlir/lib/IR/AttributeDetail.h

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/ADT/APFloat.h"
2525
#include "llvm/ADT/PointerIntPair.h"
2626
#include "llvm/Support/TrailingObjects.h"
27+
#include <mutex>
2728

2829
namespace mlir {
2930
namespace detail {
@@ -401,7 +402,8 @@ class DistinctAttributeUniquer {
401402
/// is freed after the destruction of the distinct attribute allocator.
402403
class DistinctAttributeAllocator {
403404
public:
404-
DistinctAttributeAllocator() = default;
405+
DistinctAttributeAllocator(bool threadingIsEnabled)
406+
: threadingIsEnabled(threadingIsEnabled), useThreadLocalAllocator(true) {};
405407

406408
DistinctAttributeAllocator(DistinctAttributeAllocator &&) = delete;
407409
DistinctAttributeAllocator(const DistinctAttributeAllocator &) = delete;
@@ -411,12 +413,49 @@ class DistinctAttributeAllocator {
411413
/// Allocates a distinct attribute storage using a thread local bump pointer
412414
/// allocator to enable synchronization free parallel allocations.
413415
DistinctAttrStorage *allocate(Attribute referencedAttr) {
414-
return new (allocatorCache.get().Allocate<DistinctAttrStorage>())
415-
DistinctAttrStorage(referencedAttr);
416+
if (!useThreadLocalAllocator && threadingIsEnabled) {
417+
std::scoped_lock<std::mutex> lock(allocatorMutex);
418+
return allocateImpl(referencedAttr);
419+
}
420+
return allocateImpl(referencedAttr);
421+
}
422+
423+
/// Sets a flag that stores if multithreading is enabled. The flag is used to
424+
/// decide if locking is needed when using a non thread-safe allocator.
425+
void disableMultiThreading(bool disable = true) {
426+
threadingIsEnabled = !disable;
427+
}
428+
429+
/// Sets a flag to disable using thread local bump pointer allocators and use
430+
/// a single thread-safe allocator. Use this to persist allocated storage
431+
/// beyond the lifetime of a child thread calling this function while ensuring
432+
/// thread-safe allocation.
433+
void disableThreadLocalStorage(bool disable = true) {
434+
useThreadLocalAllocator = !disable;
416435
}
417436

418437
private:
438+
DistinctAttrStorage *allocateImpl(Attribute referencedAttr) {
439+
return new (getAllocatorInUse().Allocate<DistinctAttrStorage>())
440+
DistinctAttrStorage(referencedAttr);
441+
}
442+
443+
/// If threading is disabled on the owning MLIR context, a normal non
444+
/// thread-local, non-thread safe bump pointer allocator is used instead to
445+
/// prevent use-after-free errors whenever attribute storage created on a
446+
/// crash recover thread is accessed after the thread joins.
447+
llvm::BumpPtrAllocator &getAllocatorInUse() {
448+
if (useThreadLocalAllocator)
449+
return allocatorCache.get();
450+
return allocator;
451+
}
452+
419453
ThreadLocalCache<llvm::BumpPtrAllocator> allocatorCache;
454+
llvm::BumpPtrAllocator allocator;
455+
std::mutex allocatorMutex;
456+
457+
bool threadingIsEnabled : 1;
458+
bool useThreadLocalAllocator : 1;
420459
};
421460
} // namespace detail
422461
} // namespace mlir

mlir/lib/IR/MLIRContext.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ class MLIRContextImpl {
268268

269269
public:
270270
MLIRContextImpl(bool threadingIsEnabled)
271-
: threadingIsEnabled(threadingIsEnabled) {
271+
: threadingIsEnabled(threadingIsEnabled),
272+
distinctAttributeAllocator(threadingIsEnabled) {
272273
if (threadingIsEnabled) {
273274
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
274275
threadPool = ownedThreadPool.get();
@@ -596,6 +597,7 @@ void MLIRContext::disableMultithreading(bool disable) {
596597
// Update the threading mode for each of the uniquers.
597598
impl->affineUniquer.disableMultithreading(disable);
598599
impl->attributeUniquer.disableMultithreading(disable);
600+
impl->distinctAttributeAllocator.disableMultiThreading(disable);
599601
impl->typeUniquer.disableMultithreading(disable);
600602

601603
// Destroy thread pool (stop all threads) if it is no longer needed, or create
@@ -717,6 +719,10 @@ bool MLIRContext::isOperationRegistered(StringRef name) {
717719
return RegisteredOperationName::lookup(name, this).has_value();
718720
}
719721

722+
void MLIRContext::disableThreadLocalStorage(bool disable) {
723+
getImpl().distinctAttributeAllocator.disableThreadLocalStorage(disable);
724+
}
725+
720726
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
721727
auto &impl = context->getImpl();
722728
assert(impl.multiThreadedExecutionContext == 0 &&

mlir/lib/Pass/PassCrashRecovery.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,15 @@ struct FileReproducerStream : public mlir::ReproducerStream {
414414

415415
LogicalResult PassManager::runWithCrashRecovery(Operation *op,
416416
AnalysisManager am) {
417+
// Notify the context to disable the use of thread-local storage while the
418+
// pass manager is running in a crash recovery context thread. Re-enable the
419+
// thread local storage upon function exit. This is required to persist any
420+
// attribute storage allocated during passes beyond the lifetime of the
421+
// recovery context thread.
422+
MLIRContext *ctx = getContext();
423+
ctx->disableThreadLocalStorage();
424+
auto guard =
425+
llvm::make_scope_exit([ctx]() { ctx->enableThreadLocalStorage(); });
417426
crashReproGenerator->initialize(getPasses(), op, verifyPasses);
418427

419428
// Safely invoke the passes within a recovery context.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Test that the enable-debug-info-scope-on-llvm-func pass can create its
2+
// distinct attributes when running in the crash reproducer thread.
3+
4+
// RUN: mlir-opt --mlir-disable-threading --mlir-pass-pipeline-crash-reproducer=. \
5+
// RUN: --pass-pipeline="builtin.module(ensure-debug-info-scope-on-llvm-func)" \
6+
// RUN: --mlir-print-debuginfo %s | FileCheck %s
7+
8+
// RUN: mlir-opt --mlir-pass-pipeline-crash-reproducer=. \
9+
// RUN: --pass-pipeline="builtin.module(ensure-debug-info-scope-on-llvm-func)" \
10+
// RUN: --mlir-print-debuginfo %s | FileCheck %s
11+
12+
module {
13+
llvm.func @func_no_debug() {
14+
llvm.return loc(unknown)
15+
} loc(unknown)
16+
} loc(unknown)
17+
18+
// CHECK-LABEL: llvm.func @func_no_debug()
19+
// CHECK: llvm.return loc(#loc
20+
// CHECK: loc(#loc[[LOC:[0-9]+]])
21+
// CHECK: #di_compile_unit = #llvm.di_compile_unit<id = distinct[{{.*}}]<>,
22+
// CHECK: #di_subprogram = #llvm.di_subprogram<id = distinct[{{.*}}]<>
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// This test verifies that when running with crash reproduction enabled, distinct
2+
// attribute storage is not allocated in thread-local storage. Since crash
3+
// reproduction runs the pass manager in a separate thread, using thread-local
4+
// storage for distinct attributes causes use-after-free errors once the thread
5+
// that runs the pass manager joins.
6+
7+
// RUN: mlir-opt --mlir-disable-threading --mlir-pass-pipeline-crash-reproducer=. %s -test-distinct-attrs | FileCheck %s
8+
// RUN: mlir-opt --mlir-pass-pipeline-crash-reproducer=. %s -test-distinct-attrs | FileCheck %s
9+
10+
// CHECK: #[[DIST0:.*]] = distinct[0]<42 : i32>
11+
// CHECK: #[[DIST1:.*]] = distinct[1]<42 : i32>
12+
#distinct = distinct[0]<42 : i32>
13+
14+
// CHECK: @foo_1
15+
func.func @foo_1() {
16+
// CHECK: "test.op"() {distinct.input = #[[DIST0]], distinct.output = #[[DIST1]]}
17+
"test.op"() {distinct.input = #distinct} : () -> ()
18+
}

0 commit comments

Comments
 (0)