Skip to content

Commit f0c9fae

Browse files
committed
[WIP][OpenMP] Remove dependency on libffi from offloading runtime
Summary: This patch attempts to remove the dependency on `libffi` by instead emitting the host / CPU kernels using an aggregate struct made from the captured context. This callows us to have a fixed function prototype we can call directly rather than requiring an extra library to decode the ABI to call a function with N (non variadic) arguments.
1 parent c618ae1 commit f0c9fae

40 files changed

+6393
-4688
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5927,12 +5927,16 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
59275927

59285928
CodeGenFunction CGF(CGM, true);
59295929
llvm::OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
5930-
[&CGF, &D, &CodeGen](StringRef EntryFnName) {
5930+
[&CGF, &D, &CodeGen, this](StringRef EntryFnName) {
59315931
const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
59325932

59335933
CGOpenMPTargetRegionInfo CGInfo(CS, CodeGen, EntryFnName);
59345934
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
5935-
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D.getBeginLoc());
5935+
if (CGM.getLangOpts().OpenMPIsTargetDevice && !isGPU())
5936+
return CGF.GenerateOpenMPCapturedStmtFunctionAggregate(
5937+
CS, D.getBeginLoc());
5938+
else
5939+
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D.getBeginLoc());
59365940
};
59375941

59385942
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,102 @@ static llvm::Function *emitOutlinedFunctionPrologue(
629629
return F;
630630
}
631631

632+
static llvm::Function *emitOutlinedFunctionPrologueAggregate(
633+
CodeGenFunction &CGF, FunctionArgList &Args,
634+
llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
635+
&LocalAddrs,
636+
llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>>
637+
&VLASizes,
638+
llvm::Value *&CXXThisValue, const CapturedStmt &CS, SourceLocation Loc,
639+
StringRef FunctionName) {
640+
const CapturedDecl *CD = CS.getCapturedDecl();
641+
const RecordDecl *RD = CS.getCapturedRecordDecl();
642+
643+
CXXThisValue = nullptr;
644+
// Build the argument list.
645+
CodeGenModule &CGM = CGF.CGM;
646+
ASTContext &Ctx = CGM.getContext();
647+
Args.append(CD->param_begin(), CD->param_end());
648+
649+
// Create the function declaration.
650+
const CGFunctionInfo &FuncInfo =
651+
CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, Args);
652+
llvm::FunctionType *FuncLLVMTy = CGM.getTypes().GetFunctionType(FuncInfo);
653+
654+
auto *F =
655+
llvm::Function::Create(FuncLLVMTy, llvm::GlobalValue::InternalLinkage,
656+
FunctionName, &CGM.getModule());
657+
CGM.SetInternalFunctionAttributes(CD, F, FuncInfo);
658+
if (CD->isNothrow())
659+
F->setDoesNotThrow();
660+
F->setDoesNotRecurse();
661+
662+
// Generate the function.
663+
CGF.StartFunction(CD, Ctx.VoidTy, F, FuncInfo, Args, Loc, Loc);
664+
Address ContextAddr = CGF.GetAddrOfLocalVar(CD->getContextParam());
665+
llvm::Value *ContextV = CGF.Builder.CreateLoad(ContextAddr);
666+
LValue ContextLV = CGF.MakeNaturalAlignAddrLValue(
667+
ContextV, CGM.getContext().getTagDeclType(RD));
668+
auto I = CS.captures().begin();
669+
for (const FieldDecl *FD : RD->fields()) {
670+
LValue FieldLV = CGF.EmitLValueForFieldInitialization(ContextLV, FD);
671+
// Do not map arguments if we emit function with non-original types.
672+
Address LocalAddr = FieldLV.getAddress(CGF);
673+
// If we are capturing a pointer by copy we don't need to do anything, just
674+
// use the value that we get from the arguments.
675+
if (I->capturesVariableByCopy() && FD->getType()->isAnyPointerType()) {
676+
const VarDecl *CurVD = I->getCapturedVar();
677+
LocalAddrs.insert({FD, {CurVD, LocalAddr}});
678+
++I;
679+
continue;
680+
}
681+
682+
LValue ArgLVal =
683+
CGF.MakeAddrLValue(LocalAddr, FD->getType(), AlignmentSource::Decl);
684+
if (FD->hasCapturedVLAType()) {
685+
llvm::Value *ExprArg = CGF.EmitLoadOfScalar(ArgLVal, I->getLocation());
686+
const VariableArrayType *VAT = FD->getCapturedVLAType();
687+
VLASizes.try_emplace(FD, VAT->getSizeExpr(), ExprArg);
688+
} else if (I->capturesVariable()) {
689+
const VarDecl *Var = I->getCapturedVar();
690+
QualType VarTy = Var->getType();
691+
Address ArgAddr = ArgLVal.getAddress(CGF);
692+
if (ArgLVal.getType()->isLValueReferenceType()) {
693+
ArgAddr = CGF.EmitLoadOfReference(ArgLVal);
694+
} else if (!VarTy->isVariablyModifiedType() || !VarTy->isPointerType()) {
695+
assert(ArgLVal.getType()->isPointerType());
696+
ArgAddr = CGF.EmitLoadOfPointer(
697+
ArgAddr, ArgLVal.getType()->castAs<PointerType>());
698+
}
699+
LocalAddrs.insert(
700+
{FD,
701+
{Var, Address(ArgAddr.getBasePointer(), ArgAddr.getElementType(),
702+
Ctx.getDeclAlign(Var))}});
703+
} else if (I->capturesVariableByCopy()) {
704+
assert(!FD->getType()->isAnyPointerType() &&
705+
"Not expecting a captured pointer.");
706+
const VarDecl *Var = I->getCapturedVar();
707+
Address CopyAddr = CGF.CreateMemTemp(FD->getType(), Ctx.getDeclAlign(FD),
708+
Var->getName());
709+
LValue CopyLVal =
710+
CGF.MakeAddrLValue(CopyAddr, FD->getType(), AlignmentSource::Decl);
711+
712+
RValue ArgRVal = CGF.EmitLoadOfLValue(ArgLVal, I->getLocation());
713+
CGF.EmitStoreThroughLValue(ArgRVal, CopyLVal);
714+
715+
LocalAddrs.insert({FD, {Var, CopyAddr}});
716+
} else {
717+
// If 'this' is captured, load it into CXXThisValue.
718+
assert(I->capturesThis());
719+
CXXThisValue = CGF.EmitLoadOfScalar(ArgLVal, I->getLocation());
720+
LocalAddrs.insert({FD, {nullptr, ArgLVal.getAddress(CGF)}});
721+
}
722+
++I;
723+
}
724+
725+
return F;
726+
}
727+
632728
llvm::Function *
633729
CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
634730
SourceLocation Loc) {
@@ -711,6 +807,36 @@ CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
711807
return WrapperF;
712808
}
713809

810+
llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunctionAggregate(
811+
const CapturedStmt &S, SourceLocation Loc) {
812+
assert(
813+
CapturedStmtInfo &&
814+
"CapturedStmtInfo should be set when generating the captured function");
815+
const CapturedDecl *CD = S.getCapturedDecl();
816+
// Build the argument list.
817+
FunctionArgList Args;
818+
llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
819+
llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
820+
StringRef FunctionName = CapturedStmtInfo->getHelperName();
821+
llvm::Function *F = emitOutlinedFunctionPrologueAggregate(
822+
*this, Args, LocalAddrs, VLASizes, CXXThisValue, S, Loc, FunctionName);
823+
CodeGenFunction::OMPPrivateScope LocalScope(*this);
824+
for (const auto &LocalAddrPair : LocalAddrs) {
825+
if (LocalAddrPair.second.first) {
826+
LocalScope.addPrivate(LocalAddrPair.second.first,
827+
LocalAddrPair.second.second);
828+
}
829+
}
830+
(void)LocalScope.Privatize();
831+
for (const auto &VLASizePair : VLASizes)
832+
VLASizeMap[VLASizePair.second.first] = VLASizePair.second.second;
833+
PGO.assignRegionCounters(GlobalDecl(CD), F);
834+
CapturedStmtInfo->EmitBody(*this, CD->getBody());
835+
(void)LocalScope.ForceCleanup();
836+
FinishFunction(CD->getBodyRBrace());
837+
return F;
838+
}
839+
714840
//===----------------------------------------------------------------------===//
715841
// OpenMP Directive Emission
716842
//===----------------------------------------------------------------------===//

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3643,6 +3643,9 @@ class CodeGenFunction : public CodeGenTypeCache {
36433643
Address GenerateCapturedStmtArgument(const CapturedStmt &S);
36443644
llvm::Function *GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
36453645
SourceLocation Loc);
3646+
llvm::Function *
3647+
GenerateOpenMPCapturedStmtFunctionAggregate(const CapturedStmt &S,
3648+
SourceLocation Loc);
36463649
void GenerateOpenMPCapturedVars(const CapturedStmt &S,
36473650
SmallVectorImpl<llvm::Value *> &CapturedVars);
36483651
void emitOMPSimpleStore(LValue LVal, RValue RVal, QualType RValTy,

clang/test/OpenMP/declare_target_codegen.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ int bar() { return 1 + foo() + bar() + baz1() + baz2(); }
150150
int maini1() {
151151
int a;
152152
static long aa = 32 + bbb + ccc + fff + ggg;
153-
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}maini1{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr noundef nonnull align {{[0-9]+}} dereferenceable({{[0-9]+}}) %{{.*}}, i64 {{.*}}, i64 {{.*}})
153+
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}maini1{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr {{.*}})
154154
#pragma omp target map(tofrom \
155155
: a, b)
156156
{
@@ -163,7 +163,7 @@ int maini1() {
163163

164164
int baz3() { return 2 + baz2(); }
165165
int baz2() {
166-
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz2{{.*}}_l[[@LINE+1]](ptr {{.*}}, i64 {{.*}})
166+
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz2{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr {{.*}})
167167
#pragma omp target parallel
168168
++c;
169169
return 2 + baz3();
@@ -175,7 +175,7 @@ static __typeof(create) __t_create __attribute__((__weakref__("__create")));
175175

176176
int baz5() {
177177
bool a;
178-
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz5{{.*}}_l[[@LINE+1]](ptr {{.*}}, i64 {{.*}})
178+
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz5{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr {{.*}})
179179
#pragma omp target
180180
a = __extension__(void *) & __t_create != 0;
181181
return a;

clang/test/OpenMP/declare_target_link_codegen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int maini1() {
5252
return 0;
5353
}
5454

55-
// DEVICE: define weak_odr protected void @__omp_offloading_{{.*}}_{{.*}}maini1{{.*}}_l44(ptr {{[^,]+}}, ptr noundef nonnull align {{[0-9]+}} dereferenceable{{[^,]*}}
55+
// DEVICE: define weak_odr protected void @__omp_offloading_{{.*}}_{{.*}}maini1{{.*}}_l44(ptr {{[^,]+}}, ptr {{[^,]*}}
5656
// DEVICE: [[C_REF:%.+]] = load ptr, ptr @c_decl_tgt_ref_ptr,
5757
// DEVICE: [[C:%.+]] = load i32, ptr [[C_REF]],
5858
// DEVICE: store i32 [[C]], ptr %

0 commit comments

Comments
 (0)