Skip to content

[CodeExtractor] Allow to use 0 addr space for aggregate arg #66998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion llvm/include/llvm/Transforms/Utils/CodeExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class CodeExtractorAnalysisCache {
// label, if non-empty, otherwise "extracted".
std::string Suffix;

// If true, the outlined function has aggregate argument in zero address
// space.
bool ArgsInZeroAddressSpace;

public:
/// Create a code extractor for a sequence of blocks.
///
Expand All @@ -128,13 +132,16 @@ class CodeExtractorAnalysisCache {
/// Any new allocations will be placed in the AllocationBlock, unless
/// it is null, in which case it will be placed in the entry block of
/// the function from which the code is being extracted.
/// If ArgsInZeroAddressSpace param is set to true, then the aggregate
/// param pointer of the outlined function is declared in zero address
/// space.
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
BranchProbabilityInfo *BPI = nullptr,
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
bool AllowAlloca = false,
BasicBlock *AllocationBlock = nullptr,
std::string Suffix = "");
std::string Suffix = "", bool ArgsInZeroAddressSpace = false);

/// Create a code extractor for a loop body.
///
Expand Down
17 changes: 13 additions & 4 deletions llvm/lib/Transforms/Utils/CodeExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
bool AggregateArgs, BlockFrequencyInfo *BFI,
BranchProbabilityInfo *BPI, AssumptionCache *AC,
bool AllowVarArgs, bool AllowAlloca,
BasicBlock *AllocationBlock, std::string Suffix)
BasicBlock *AllocationBlock, std::string Suffix,
bool ArgsInZeroAddressSpace)
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
AllowVarArgs(AllowVarArgs),
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
Suffix(Suffix) {}
Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}

CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
BlockFrequencyInfo *BFI,
Expand Down Expand Up @@ -866,7 +867,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
StructType *StructTy = nullptr;
if (AggregateArgs && !AggParamTy.empty()) {
StructTy = StructType::get(M->getContext(), AggParamTy);
ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
ParamTy.push_back(PointerType::get(
StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
}

LLVM_DEBUG({
Expand Down Expand Up @@ -1186,8 +1188,15 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
: &codeReplacer->getParent()->front().front());
params.push_back(Struct);

if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
auto *StructSpaceCast = new AddrSpaceCastInst(
Struct, PointerType ::get(Context, 0), "structArg.ascast");
StructSpaceCast->insertAfter(Struct);
params.push_back(StructSpaceCast);
} else {
params.push_back(Struct);
}
// Store aggregated inputs in the struct.
for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
if (inputs.contains(StructValues[i])) {
Expand Down
60 changes: 60 additions & 0 deletions llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,64 @@ TEST(CodeExtractor, PartialAggregateArgs) {
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}

TEST(CodeExtractor, OpenMPAggregateArgs) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"ir(
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8"
target triple = "amdgcn-amd-amdhsa"

define void @foo(ptr %0) {
%2= alloca ptr, align 8, addrspace(5)
%3 = addrspacecast ptr addrspace(5) %2 to ptr
store ptr %0, ptr %3, align 8
%4 = load ptr, ptr %3, align 8
br label %entry

entry:
br label %extract

extract:
store i64 10, ptr %4, align 4
br label %exit

exit:
ret void
}
)ir",
Err, Ctx));
Function *Func = M->getFunction("foo");
SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};

// Create the CodeExtractor with arguments aggregation enabled.
// Outlined function argument should be declared in 0 address space
// even if the default alloca address space is 5.
CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocaBlock*/ &Func->getEntryBlock(),
/* Suffix */ ".outlined",
/* ArgsInZeroAddressSpace */ true);

EXPECT_TRUE(CE.isEligible());

CodeExtractorAnalysisCache CEAC(*Func);
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
BasicBlock *CommonExit = nullptr;
CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
CE.findInputsOutputs(Inputs, Outputs, SinkingCands);

Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
EXPECT_TRUE(Outlined);
EXPECT_EQ(Outlined->arg_size(), 1U);
// Check address space of outlined argument is ptr in address space 0
EXPECT_EQ(Outlined->getArg(0)->getType(),
PointerType::get(M->getContext(), 0));
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
} // end anonymous namespace