Skip to content

Commit eee8dd9

Browse files
[CodeExtractor] Allow to use 0 addr space for aggregate arg (#66998)
The user of CodeExtractor should be able to specify that the aggregate argument should be passed as a pointer in zero address space. CodeExtractor is used to generate outlined functions required by OpenMP runtime. The arguments of the outlined functions for OpenMP GPU code are in 0 address space. 0 address space does not need to be the default address space for GPU device. That's why there is a need to allow the user of CodeExtractor to specify, that the allocated aggregate parameter is passed as pointer in zero address space.
1 parent ddf1de2 commit eee8dd9

File tree

3 files changed

+81
-5
lines changed

3 files changed

+81
-5
lines changed

llvm/include/llvm/Transforms/Utils/CodeExtractor.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class CodeExtractorAnalysisCache {
114114
// label, if non-empty, otherwise "extracted".
115115
std::string Suffix;
116116

117+
// If true, the outlined function has aggregate argument in zero address
118+
// space.
119+
bool ArgsInZeroAddressSpace;
120+
117121
public:
118122
/// Create a code extractor for a sequence of blocks.
119123
///
@@ -128,13 +132,16 @@ class CodeExtractorAnalysisCache {
128132
/// Any new allocations will be placed in the AllocationBlock, unless
129133
/// it is null, in which case it will be placed in the entry block of
130134
/// the function from which the code is being extracted.
135+
/// If ArgsInZeroAddressSpace param is set to true, then the aggregate
136+
/// param pointer of the outlined function is declared in zero address
137+
/// space.
131138
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
132139
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
133140
BranchProbabilityInfo *BPI = nullptr,
134141
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
135142
bool AllowAlloca = false,
136143
BasicBlock *AllocationBlock = nullptr,
137-
std::string Suffix = "");
144+
std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
138145

139146
/// Create a code extractor for a loop body.
140147
///

llvm/lib/Transforms/Utils/CodeExtractor.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
245245
bool AggregateArgs, BlockFrequencyInfo *BFI,
246246
BranchProbabilityInfo *BPI, AssumptionCache *AC,
247247
bool AllowVarArgs, bool AllowAlloca,
248-
BasicBlock *AllocationBlock, std::string Suffix)
248+
BasicBlock *AllocationBlock, std::string Suffix,
249+
bool ArgsInZeroAddressSpace)
249250
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
250251
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
251252
AllowVarArgs(AllowVarArgs),
252253
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
253-
Suffix(Suffix) {}
254+
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
254255

255256
CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
256257
BlockFrequencyInfo *BFI,
@@ -866,7 +867,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
866867
StructType *StructTy = nullptr;
867868
if (AggregateArgs && !AggParamTy.empty()) {
868869
StructTy = StructType::get(M->getContext(), AggParamTy);
869-
ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
870+
ParamTy.push_back(PointerType::get(
871+
StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
870872
}
871873

872874
LLVM_DEBUG({
@@ -1187,8 +1189,15 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
11871189
StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
11881190
AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
11891191
: &codeReplacer->getParent()->front().front());
1190-
params.push_back(Struct);
11911192

1193+
if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
1194+
auto *StructSpaceCast = new AddrSpaceCastInst(
1195+
Struct, PointerType ::get(Context, 0), "structArg.ascast");
1196+
StructSpaceCast->insertAfter(Struct);
1197+
params.push_back(StructSpaceCast);
1198+
} else {
1199+
params.push_back(Struct);
1200+
}
11921201
// Store aggregated inputs in the struct.
11931202
for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
11941203
if (inputs.contains(StructValues[i])) {

llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,4 +555,64 @@ TEST(CodeExtractor, PartialAggregateArgs) {
555555
EXPECT_FALSE(verifyFunction(*Outlined));
556556
EXPECT_FALSE(verifyFunction(*Func));
557557
}
558+
559+
TEST(CodeExtractor, OpenMPAggregateArgs) {
560+
LLVMContext Ctx;
561+
SMDiagnostic Err;
562+
std::unique_ptr<Module> M(parseAssemblyString(R"ir(
563+
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8"
564+
target triple = "amdgcn-amd-amdhsa"
565+
566+
define void @foo(ptr %0) {
567+
%2= alloca ptr, align 8, addrspace(5)
568+
%3 = addrspacecast ptr addrspace(5) %2 to ptr
569+
store ptr %0, ptr %3, align 8
570+
%4 = load ptr, ptr %3, align 8
571+
br label %entry
572+
573+
entry:
574+
br label %extract
575+
576+
extract:
577+
store i64 10, ptr %4, align 4
578+
br label %exit
579+
580+
exit:
581+
ret void
582+
}
583+
)ir",
584+
Err, Ctx));
585+
Function *Func = M->getFunction("foo");
586+
SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
587+
588+
// Create the CodeExtractor with arguments aggregation enabled.
589+
// Outlined function argument should be declared in 0 address space
590+
// even if the default alloca address space is 5.
591+
CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
592+
/* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
593+
/* BranchProbabilityInfo */ nullptr,
594+
/* AssumptionCache */ nullptr,
595+
/* AllowVarArgs */ true,
596+
/* AllowAlloca */ true,
597+
/* AllocaBlock*/ &Func->getEntryBlock(),
598+
/* Suffix */ ".outlined",
599+
/* ArgsInZeroAddressSpace */ true);
600+
601+
EXPECT_TRUE(CE.isEligible());
602+
603+
CodeExtractorAnalysisCache CEAC(*Func);
604+
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
605+
BasicBlock *CommonExit = nullptr;
606+
CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
607+
CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
608+
609+
Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
610+
EXPECT_TRUE(Outlined);
611+
EXPECT_EQ(Outlined->arg_size(), 1U);
612+
// Check address space of outlined argument is ptr in address space 0
613+
EXPECT_EQ(Outlined->getArg(0)->getType(),
614+
PointerType::get(M->getContext(), 0));
615+
EXPECT_FALSE(verifyFunction(*Outlined));
616+
EXPECT_FALSE(verifyFunction(*Func));
617+
}
558618
} // end anonymous namespace

0 commit comments

Comments
 (0)