Skip to content

Commit 922ab70

Browse files
authored
[MLIR][OpenMP] Extend omp.private materialization support: dealloc (#90841)
Extends current support for delayed privatization during translation to LLVM IR. This adds support for materlizaing the `dealloc` region in `omp.private` ops when this region contains clean-up/deallocation logic that needs to be executed at the end of the parallel region. This changes the `OMPIRBuilder` slightly to execute the finalization callback **after** the privatization callback. This allows us to collect information about privatized variables on the MLIR and LLVM sides so that we can properly emit deallocation logic.
1 parent f8fedfb commit 922ab70

File tree

3 files changed

+125
-33
lines changed

3 files changed

+125
-33
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,19 +1500,6 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
15001500
};
15011501
}
15021502

1503-
// Adjust the finalization stack, verify the adjustment, and call the
1504-
// finalize function a last time to finalize values between the pre-fini
1505-
// block and the exit block if we left the parallel "the normal way".
1506-
auto FiniInfo = FinalizationStack.pop_back_val();
1507-
(void)FiniInfo;
1508-
assert(FiniInfo.DK == OMPD_parallel &&
1509-
"Unexpected finalization stack state!");
1510-
1511-
Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1512-
1513-
InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1514-
FiniCB(PreFiniIP);
1515-
15161503
OI.OuterAllocaBB = OuterAllocaBlock;
15171504
OI.EntryBB = PRegEntryBB;
15181505
OI.ExitBB = PRegExitBB;
@@ -1637,6 +1624,19 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
16371624
dbgs() << " PBR: " << BB->getName() << "\n";
16381625
});
16391626

1627+
// Adjust the finalization stack, verify the adjustment, and call the
1628+
// finalize function a last time to finalize values between the pre-fini
1629+
// block and the exit block if we left the parallel "the normal way".
1630+
auto FiniInfo = FinalizationStack.pop_back_val();
1631+
(void)FiniInfo;
1632+
assert(FiniInfo.DK == OMPD_parallel &&
1633+
"Unexpected finalization stack state!");
1634+
1635+
Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1636+
1637+
InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1638+
FiniCB(PreFiniIP);
1639+
16401640
// Register the outlined info.
16411641
addOutlineInfo(std::move(OI));
16421642

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/Transforms/Utils/ModuleUtils.h"
3434

3535
#include <any>
36+
#include <iterator>
3637
#include <optional>
3738
#include <utility>
3839

@@ -878,36 +879,40 @@ static void collectReductionInfo(
878879
}
879880

880881
/// handling of DeclareReductionOp's cleanup region
881-
static LogicalResult inlineReductionCleanup(
882-
llvm::SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
883-
llvm::ArrayRef<llvm::Value *> privateReductionVariables,
884-
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) {
885-
for (auto [i, reductionDecl] : llvm::enumerate(reductionDecls)) {
886-
Region &cleanupRegion = reductionDecl.getCleanupRegion();
887-
if (cleanupRegion.empty())
882+
static LogicalResult
883+
inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
884+
llvm::ArrayRef<llvm::Value *> privateVariables,
885+
LLVM::ModuleTranslation &moduleTranslation,
886+
llvm::IRBuilderBase &builder, StringRef regionName,
887+
bool shouldLoadCleanupRegionArg = true) {
888+
for (auto [i, cleanupRegion] : llvm::enumerate(cleanupRegions)) {
889+
if (cleanupRegion->empty())
888890
continue;
889891

890892
// map the argument to the cleanup region
891-
Block &entry = cleanupRegion.front();
893+
Block &entry = cleanupRegion->front();
892894

893895
llvm::Instruction *potentialTerminator =
894896
builder.GetInsertBlock()->empty() ? nullptr
895897
: &builder.GetInsertBlock()->back();
896898
if (potentialTerminator && potentialTerminator->isTerminator())
897899
builder.SetInsertPoint(potentialTerminator);
898-
llvm::Value *reductionVar = builder.CreateLoad(
899-
moduleTranslation.convertType(entry.getArgument(0).getType()),
900-
privateReductionVariables[i]);
900+
llvm::Value *prviateVarValue =
901+
shouldLoadCleanupRegionArg
902+
? builder.CreateLoad(
903+
moduleTranslation.convertType(entry.getArgument(0).getType()),
904+
privateVariables[i])
905+
: privateVariables[i];
901906

902-
moduleTranslation.mapValue(entry.getArgument(0), reductionVar);
907+
moduleTranslation.mapValue(entry.getArgument(0), prviateVarValue);
903908

904-
if (failed(inlineConvertOmpRegions(cleanupRegion, "omp.reduction.cleanup",
905-
builder, moduleTranslation)))
909+
if (failed(inlineConvertOmpRegions(*cleanupRegion, regionName, builder,
910+
moduleTranslation)))
906911
return failure();
907912

908913
// clear block argument mapping in case it needs to be re-created with a
909914
// different source for another use of the same reduction decl
910-
moduleTranslation.forgetMapping(cleanupRegion);
915+
moduleTranslation.forgetMapping(*cleanupRegion);
911916
}
912917
return success();
913918
}
@@ -1110,8 +1115,14 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
11101115
builder.restoreIP(nextInsertionPoint);
11111116

11121117
// after the workshare loop, deallocate private reduction variables
1113-
return inlineReductionCleanup(reductionDecls, privateReductionVariables,
1114-
moduleTranslation, builder);
1118+
SmallVector<Region *> reductionRegions;
1119+
llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
1120+
[](omp::DeclareReductionOp reductionDecl) {
1121+
return &reductionDecl.getCleanupRegion();
1122+
});
1123+
return inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
1124+
moduleTranslation, builder,
1125+
"omp.reduction.cleanup");
11151126
}
11161127

11171128
/// A RAII class that on construction replaces the region arguments of the
@@ -1267,6 +1278,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12671278
}
12681279
};
12691280

1281+
SmallVector<omp::PrivateClauseOp> privatizerClones;
1282+
SmallVector<llvm::Value *> privateVariables;
1283+
12701284
// TODO: Perform appropriate actions according to the data-sharing
12711285
// attribute (shared, private, firstprivate, ...) of variables.
12721286
// Currently shared and private are supported.
@@ -1356,12 +1370,17 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13561370
opInst.emitError("failed to inline `alloc` region of an `omp.private` "
13571371
"op in the parallel region");
13581372
bodyGenStatus = failure();
1373+
privatizerClone.erase();
13591374
} else {
13601375
assert(yieldedValues.size() == 1);
13611376
replacementValue = yieldedValues.front();
1377+
1378+
// Keep the LLVM replacement value and the op clone in case we need to
1379+
// emit cleanup (i.e. deallocation) logic.
1380+
privateVariables.push_back(replacementValue);
1381+
privatizerClones.push_back(privatizerClone);
13621382
}
13631383

1364-
privatizerClone.erase();
13651384
builder.restoreIP(oldIP);
13661385
}
13671386

@@ -1376,8 +1395,25 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
13761395

13771396
// if the reduction has a cleanup region, inline it here to finalize the
13781397
// reduction variables
1379-
if (failed(inlineReductionCleanup(reductionDecls, privateReductionVariables,
1380-
moduleTranslation, builder)))
1398+
SmallVector<Region *> reductionCleanupRegions;
1399+
llvm::transform(reductionDecls, std::back_inserter(reductionCleanupRegions),
1400+
[](omp::DeclareReductionOp reductionDecl) {
1401+
return &reductionDecl.getCleanupRegion();
1402+
});
1403+
if (failed(inlineOmpRegionCleanup(
1404+
reductionCleanupRegions, privateReductionVariables,
1405+
moduleTranslation, builder, "omp.reduction.cleanup")))
1406+
bodyGenStatus = failure();
1407+
1408+
SmallVector<Region *> privateCleanupRegions;
1409+
llvm::transform(privatizerClones, std::back_inserter(privateCleanupRegions),
1410+
[](omp::PrivateClauseOp privatizer) {
1411+
return &privatizer.getDeallocRegion();
1412+
});
1413+
1414+
if (failed(inlineOmpRegionCleanup(
1415+
privateCleanupRegions, privateVariables, moduleTranslation, builder,
1416+
"omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
13811417
bodyGenStatus = failure();
13821418

13831419
builder.restoreIP(oldIP);
@@ -1403,6 +1439,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
14031439
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
14041440
ifCond, numThreads, pbKind, isCancellable));
14051441

1442+
for (mlir::omp::PrivateClauseOp privatizerClone : privatizerClones)
1443+
privatizerClone.erase();
1444+
14061445
return bodyGenStatus;
14071446
}
14081447

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
llvm.func @free(!llvm.ptr)
4+
5+
llvm.func @parallel_op_dealloc(%arg0: !llvm.ptr) {
6+
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr) {
7+
%0 = llvm.load %arg2 : !llvm.ptr -> f32
8+
omp.terminator
9+
}
10+
llvm.return
11+
}
12+
13+
omp.private {type = firstprivate} @x.privatizer : !llvm.ptr alloc {
14+
^bb0(%arg0: !llvm.ptr):
15+
%c1 = llvm.mlir.constant(1 : i32) : i32
16+
%0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
17+
omp.yield(%0 : !llvm.ptr)
18+
} copy {
19+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
20+
%0 = llvm.load %arg0 : !llvm.ptr -> f32
21+
llvm.store %0, %arg1 : f32, !llvm.ptr
22+
omp.yield(%arg1 : !llvm.ptr)
23+
} dealloc {
24+
^bb0(%arg0: !llvm.ptr):
25+
%0 = llvm.ptrtoint %arg0 : !llvm.ptr to i64
26+
%c0 = llvm.mlir.constant(0 : i64) : i64
27+
%1 = llvm.icmp "ne" %0, %c0 : i64
28+
llvm.cond_br %1, ^bb1, ^bb2
29+
30+
^bb1:
31+
llvm.call @free(%arg0) : (!llvm.ptr) -> ()
32+
llvm.br ^bb2
33+
34+
^bb2:
35+
omp.yield
36+
}
37+
38+
// CHECK-LABEL: define internal void @parallel_op_dealloc..omp_par
39+
// CHECK: %[[LOCAL_ALLOC:.*]] = alloca float, align 4
40+
41+
// CHECK: omp.par.pre_finalize:
42+
// CHECK: br label %[[DEALLOC_REG_START:.*]]
43+
44+
// CHECK: [[DEALLOC_REG_START]]:
45+
// CHECK: %[[LOCAL_ALLOC_CONV:.*]] = ptrtoint ptr %[[LOCAL_ALLOC]] to i64
46+
// CHECK: %[[COND:.*]] = icmp ne i64 %[[LOCAL_ALLOC_CONV]], 0
47+
// CHECK: br i1 %[[COND]], label %[[DEALLOC_REG_BB1:.*]], label %[[DEALLOC_REG_BB2:.*]]
48+
49+
// CHECK: [[DEALLOC_REG_BB2]]:
50+
51+
// CHECK: [[DEALLOC_REG_BB1]]:
52+
// CHECK-NEXT: call void @free(ptr %[[LOCAL_ALLOC]])
53+
// CHECK-NEXT: br label %[[DEALLOC_REG_BB2]]

0 commit comments

Comments
 (0)