Skip to content

Commit dcf4ca5

Browse files
authored
[OpenMP][MLIR][OMPIRBuilder] Add a small optional constant alloca raise function pass to finalize, utilised in convertTarget (llvm#78818)
This patch seeks to add a mechanism to raise constant (not ConstantExpr or runtime/dynamic) sized allocations into the entry block for select functions that have been inserted into a list for processing. This processing occurs during the finalize call, after OutlinedInfo regions have completed. This currently has only been utilised for createOutlinedFunction, which is triggered for TargetOp generation in the OpenMP MLIR dialect lowering to LLVM-IR. This currently is required for Target kernels generated by createOutlinedFunction to avoid subsequent optimization passes doing some unintentional malformed optimizations for AMD kernels (unsure if it occurs for other vendors). If the allocas are generated inside of the kernel and are not in the entry block and are subsequently passed to a function this can lead to required instructions being erased or manipulated in a way that causes the kernel to run into a HSA access error. This fix is related to a series of problems found in: llvm#74603 This problem primarily presents itself for Flang's HLFIR AssignOp currently, when utilised with a scalar temporary constant on the RHS and a descriptor type on the LHS. It will generate a call to a runtime function, wrap the RHS temporary in a newly allocated descriptor (an llvm struct), and pass both the LHS and RHS descriptor into the runtime function call. This will currently be embedded into the middle of the target region in the user entry block, which means the allocas are also embedded in the middle, which seems to pose issues when later passes are executed. This issue may present itself in other HLFIR operations or unrelated operations that generate allocas as a by product, but for the moment, this one test case is the only scenario I've found this problem. Perhaps this is not the appropriate fix, I am very open to other suggestions, I've tried a few others (at varying levels of the flang/mlir compiler flow), but this one is the smallest and least intrusive change set. The other two, that come to mind (but I've not fully looked into, the former I tried a little with blocks but it had a few issues I'd need to think through): - Having a proper alloca only block (or region) generated for TargetOps that we could merge into the entry block that's generated by convertTarget's createOutlinedFunction. - Or diverging a little from Clang's current target generation and using the CodeExtractor to generate the user code as an outlined function region invoked from the kernel we make, with our kernel arguments passed into it. Similar to the current parallel generation. I am not sure how well this would intermingle with the existing parallel generation though that's layered in. Both of these methods seem like quite a divergence from the current status quo, which I am not entirely sure is merited for the small test this change aims to fix.
1 parent 47aee8b commit dcf4ca5

File tree

4 files changed

+249
-0
lines changed

4 files changed

+249
-0
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,11 @@ class OpenMPIRBuilder {
15061506
/// Collection of regions that need to be outlined during finalization.
15071507
SmallVector<OutlineInfo, 16> OutlineInfos;
15081508

1509+
/// A collection of candidate target functions that's constant allocas will
1510+
/// attempt to be raised on a call of finalize after all currently enqueued
1511+
/// outline info's have been processed.
1512+
SmallVector<llvm::Function *, 16> ConstantAllocaRaiseCandidates;
1513+
15091514
/// Collection of owned canonical loop objects that eventually need to be
15101515
/// free'd.
15111516
std::forward_list<CanonicalLoopInfo> LoopInfos;

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,29 @@ Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
633633

634634
void OpenMPIRBuilder::initialize() { initializeTypes(M); }
635635

636+
static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
637+
Function *Function) {
638+
BasicBlock &EntryBlock = Function->getEntryBlock();
639+
Instruction *MoveLocInst = EntryBlock.getFirstNonPHI();
640+
641+
// Loop over blocks looking for constant allocas, skipping the entry block
642+
// as any allocas there are already in the desired location.
643+
for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
644+
Block++) {
645+
for (auto Inst = Block->getReverseIterator()->begin();
646+
Inst != Block->getReverseIterator()->end();) {
647+
if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
648+
Inst++;
649+
if (!isa<ConstantData>(AllocaInst->getArraySize()))
650+
continue;
651+
AllocaInst->moveBeforePreserving(MoveLocInst);
652+
} else {
653+
Inst++;
654+
}
655+
}
656+
}
657+
}
658+
636659
void OpenMPIRBuilder::finalize(Function *Fn) {
637660
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
638661
SmallVector<BasicBlock *, 32> Blocks;
@@ -737,6 +760,28 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
737760
// Remove work items that have been completed.
738761
OutlineInfos = std::move(DeferredOutlines);
739762

763+
// The createTarget functions embeds user written code into
764+
// the target region which may inject allocas which need to
765+
// be moved to the entry block of our target or risk malformed
766+
// optimisations by later passes, this is only relevant for
767+
// the device pass which appears to be a little more delicate
768+
// when it comes to optimisations (however, we do not block on
769+
// that here, it's up to the inserter to the list to do so).
770+
// This notbaly has to occur after the OutlinedInfo candidates
771+
// have been extracted so we have an end product that will not
772+
// be implicitly adversely affected by any raises unless
773+
// intentionally appended to the list.
774+
// NOTE: This only does so for ConstantData, it could be extended
775+
// to ConstantExpr's with further effort, however, they should
776+
// largely be folded when they get here. Extending it to runtime
777+
// defined/read+writeable allocation sizes would be non-trivial
778+
// (need to factor in movement of any stores to variables the
779+
// allocation size depends on, as well as the usual loads,
780+
// otherwise it'll yield the wrong result after movement) and
781+
// likely be more suitable as an LLVM optimisation pass.
782+
for (Function *F : ConstantAllocaRaiseCandidates)
783+
raiseUserConstantDataAllocasToEntryBlock(Builder, F);
784+
740785
EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
741786
[](EmitMetadataErrorKind Kind,
742787
const TargetRegionEntryInfo &EntryInfo) -> void {
@@ -5043,6 +5088,12 @@ static Function *createOutlinedFunction(
50435088

50445089
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
50455090

5091+
// As we embed the user code in the middle of our target region after we
5092+
// generate entry code, we must move what allocas we can into the entry
5093+
// block to avoid possible breaking optimisations for device
5094+
if (OMPBuilder.Config.isTargetDevice())
5095+
OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
5096+
50465097
// Insert target deinit call in the device compilation pass.
50475098
Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
50485099
if (OMPBuilder.Config.isTargetDevice())

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5989,6 +5989,156 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
59895989
EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
59905990
}
59915991

5992+
TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
5993+
OpenMPIRBuilder OMPBuilder(*M);
5994+
OMPBuilder.setConfig(
5995+
OpenMPIRBuilderConfig(true, false, false, false, false, false, false));
5996+
OMPBuilder.initialize();
5997+
5998+
F->setName("func");
5999+
IRBuilder<> Builder(BB);
6000+
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
6001+
6002+
LoadInst *Value = nullptr;
6003+
StoreInst *TargetStore = nullptr;
6004+
llvm::SmallVector<llvm::Value *, 1> CapturedArgs = {
6005+
Constant::getNullValue(PointerType::get(Ctx, 0))};
6006+
6007+
auto SimpleArgAccessorCB =
6008+
[&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
6009+
llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
6010+
llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
6011+
if (!OMPBuilder.Config.isTargetDevice()) {
6012+
RetVal = cast<llvm::Value>(&Arg);
6013+
return CodeGenIP;
6014+
}
6015+
6016+
Builder.restoreIP(AllocaIP);
6017+
6018+
llvm::Value *Addr = Builder.CreateAlloca(
6019+
Arg.getType()->isPointerTy()
6020+
? Arg.getType()
6021+
: Type::getInt64Ty(Builder.getContext()),
6022+
OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
6023+
llvm::Value *AddrAscast =
6024+
Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
6025+
Builder.CreateStore(&Arg, AddrAscast);
6026+
6027+
Builder.restoreIP(CodeGenIP);
6028+
6029+
RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast);
6030+
6031+
return Builder.saveIP();
6032+
};
6033+
6034+
llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
6035+
auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
6036+
-> llvm::OpenMPIRBuilder::MapInfosTy & {
6037+
CreateDefaultMapInfos(OMPBuilder, CapturedArgs, CombinedInfos);
6038+
return CombinedInfos;
6039+
};
6040+
6041+
llvm::Value *RaiseAlloca = nullptr;
6042+
6043+
auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
6044+
OpenMPIRBuilder::InsertPointTy CodeGenIP)
6045+
-> OpenMPIRBuilder::InsertPointTy {
6046+
Builder.restoreIP(CodeGenIP);
6047+
RaiseAlloca = Builder.CreateAlloca(Builder.getInt32Ty());
6048+
Value = Builder.CreateLoad(Type::getInt32Ty(Ctx), CapturedArgs[0]);
6049+
TargetStore = Builder.CreateStore(Value, RaiseAlloca);
6050+
return Builder.saveIP();
6051+
};
6052+
6053+
IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
6054+
F->getEntryBlock().getFirstInsertionPt());
6055+
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
6056+
/*Line=*/3, /*Count=*/0);
6057+
6058+
Builder.restoreIP(
6059+
OMPBuilder.createTarget(Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
6060+
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6061+
BodyGenCB, SimpleArgAccessorCB));
6062+
6063+
Builder.CreateRetVoid();
6064+
OMPBuilder.finalize();
6065+
6066+
// Check outlined function
6067+
EXPECT_FALSE(verifyModule(*M, &errs()));
6068+
EXPECT_NE(TargetStore, nullptr);
6069+
Function *OutlinedFn = TargetStore->getFunction();
6070+
EXPECT_NE(F, OutlinedFn);
6071+
6072+
EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
6073+
// Account for the "implicit" first argument.
6074+
EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
6075+
EXPECT_EQ(OutlinedFn->arg_size(), 2U);
6076+
EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
6077+
6078+
// Check entry block, to see if we have raised our alloca
6079+
// from the body to the entry block.
6080+
auto &EntryBlock = OutlinedFn->getEntryBlock();
6081+
6082+
// Check that we have moved our alloca created in the
6083+
// BodyGenCB function, to the top of the function.
6084+
Instruction *Alloca1 = EntryBlock.getFirstNonPHI();
6085+
EXPECT_NE(Alloca1, nullptr);
6086+
EXPECT_TRUE(isa<AllocaInst>(Alloca1));
6087+
EXPECT_EQ(Alloca1, RaiseAlloca);
6088+
6089+
// Verify we have not altered the rest of the function
6090+
// inappropriately with our alloca movement.
6091+
auto *Alloca2 = Alloca1->getNextNode();
6092+
EXPECT_TRUE(isa<AllocaInst>(Alloca2));
6093+
auto *Store2 = Alloca2->getNextNode();
6094+
EXPECT_TRUE(isa<StoreInst>(Store2));
6095+
6096+
auto *InitCall = dyn_cast<CallInst>(Store2->getNextNode());
6097+
EXPECT_NE(InitCall, nullptr);
6098+
EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init");
6099+
EXPECT_EQ(InitCall->arg_size(), 2U);
6100+
EXPECT_TRUE(isa<GlobalVariable>(InitCall->getArgOperand(0)));
6101+
auto *KernelEnvGV = cast<GlobalVariable>(InitCall->getArgOperand(0));
6102+
EXPECT_TRUE(isa<ConstantStruct>(KernelEnvGV->getInitializer()));
6103+
auto *KernelEnvC = cast<ConstantStruct>(KernelEnvGV->getInitializer());
6104+
EXPECT_TRUE(isa<ConstantStruct>(KernelEnvC->getAggregateElement(0U)));
6105+
auto *ConfigC = cast<ConstantStruct>(KernelEnvC->getAggregateElement(0U));
6106+
EXPECT_EQ(ConfigC->getAggregateElement(0U),
6107+
ConstantInt::get(Type::getInt8Ty(Ctx), true));
6108+
EXPECT_EQ(ConfigC->getAggregateElement(1U),
6109+
ConstantInt::get(Type::getInt8Ty(Ctx), true));
6110+
EXPECT_EQ(ConfigC->getAggregateElement(2U),
6111+
ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
6112+
6113+
auto *EntryBlockBranch = EntryBlock.getTerminator();
6114+
EXPECT_NE(EntryBlockBranch, nullptr);
6115+
EXPECT_EQ(EntryBlockBranch->getNumSuccessors(), 2U);
6116+
6117+
// Check user code block
6118+
auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
6119+
EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
6120+
auto *Load1 = UserCodeBlock->getFirstNonPHI();
6121+
EXPECT_TRUE(isa<LoadInst>(Load1));
6122+
auto *Load2 = Load1->getNextNode();
6123+
EXPECT_TRUE(isa<LoadInst>(Load2));
6124+
EXPECT_EQ(Load2, Value);
6125+
EXPECT_EQ(Load2->getNextNode(), TargetStore);
6126+
auto *Deinit = TargetStore->getNextNode();
6127+
EXPECT_NE(Deinit, nullptr);
6128+
6129+
auto *DeinitCall = dyn_cast<CallInst>(Deinit);
6130+
EXPECT_NE(DeinitCall, nullptr);
6131+
EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit");
6132+
EXPECT_EQ(DeinitCall->arg_size(), 0U);
6133+
6134+
EXPECT_TRUE(isa<ReturnInst>(DeinitCall->getNextNode()));
6135+
6136+
// Check exit block
6137+
auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
6138+
EXPECT_EQ(ExitBlock->getName(), "worker.exit");
6139+
EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
6140+
}
6141+
59926142
TEST_F(OpenMPIRBuilderTest, CreateTask) {
59936143
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
59946144
OpenMPIRBuilder OMPBuilder(*M);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// A small condensed version of a problem requiring constant alloca raising in
4+
// Target Region Entries for user injected code, found in an issue in the Flang
5+
// compiler. Certain LLVM IR optimisation passes will perform runtime breaking
6+
// transformations on allocations not found to be in the entry block, current
7+
// OpenMP dialect lowering of TargetOp's will inject user allocations after
8+
// compiler generated entry code, in a seperate block, this test checks that
9+
// a small function which attempts to raise some of these (specifically
10+
// constant sized) allocations performs its task reasonably in these
11+
// scenarios.
12+
13+
module attributes {omp.is_target_device = true} {
14+
llvm.func @_QQmain() attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
15+
%1 = llvm.mlir.constant(1 : i64) : i64
16+
%2 = llvm.alloca %1 x !llvm.struct<(ptr)> : (i64) -> !llvm.ptr
17+
%3 = omp.map_info var_ptr(%2 : !llvm.ptr, !llvm.struct<(ptr)>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr
18+
omp.target map_entries(%3 -> %arg0 : !llvm.ptr) {
19+
^bb0(%arg0: !llvm.ptr):
20+
%4 = llvm.mlir.constant(1 : i32) : i32
21+
%5 = llvm.alloca %4 x !llvm.struct<(ptr)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
22+
%6 = llvm.mlir.constant(50 : i32) : i32
23+
%7 = llvm.mlir.constant(1 : i64) : i64
24+
%8 = llvm.alloca %7 x i32 : (i64) -> !llvm.ptr
25+
llvm.store %6, %8 : i32, !llvm.ptr
26+
%9 = llvm.mlir.undef : !llvm.struct<(ptr)>
27+
%10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr)>
28+
llvm.store %10, %5 : !llvm.struct<(ptr)>, !llvm.ptr
29+
%88 = llvm.call @_ExternalCall(%arg0, %5) : (!llvm.ptr, !llvm.ptr) -> !llvm.struct<()>
30+
omp.terminator
31+
}
32+
llvm.return
33+
}
34+
llvm.func @_ExternalCall(!llvm.ptr, !llvm.ptr) -> !llvm.struct<()>
35+
}
36+
37+
// CHECK: define weak_odr protected void @{{.*}}QQmain_l{{.*}}({{.*}}, {{.*}}) {
38+
// CHECK-NEXT: entry:
39+
// CHECK-NEXT: %[[MOVED_ALLOCA1:.*]] = alloca { ptr }, align 8
40+
// CHECK-NEXT: %[[MOVED_ALLOCA2:.*]] = alloca i32, i64 1, align 4
41+
// CHECK-NEXT: %[[MAP_ARG_ALLOCA:.*]] = alloca ptr, align 8
42+
43+
// CHECK: user_code.entry: ; preds = %entry

0 commit comments

Comments
 (0)