Skip to content

Commit 4a3ea3d

Browse files
committed
[WIP][OpenMP] Remove dependency on libffi from offloading runtime
Summary: This patch attempts to remove the dependency on `libffi` by instead emitting the host / CPU kernels using an aggregate struct made from the captured context. This callows us to have a fixed function prototype we can call directly rather than requiring an extra library to decode the ABI to call a function with N (non variadic) arguments.
1 parent 81d20d8 commit 4a3ea3d

40 files changed

+6390
-4681
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5932,12 +5932,16 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
59325932

59335933
CodeGenFunction CGF(CGM, true);
59345934
llvm::OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
5935-
[&CGF, &D, &CodeGen](StringRef EntryFnName) {
5935+
[&CGF, &D, &CodeGen, this](StringRef EntryFnName) {
59365936
const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);
59375937

59385938
CGOpenMPTargetRegionInfo CGInfo(CS, CodeGen, EntryFnName);
59395939
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
5940-
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D.getBeginLoc());
5940+
if (CGM.getLangOpts().OpenMPIsTargetDevice && !isGPU())
5941+
return CGF.GenerateOpenMPCapturedStmtFunctionAggregate(
5942+
CS, D.getBeginLoc());
5943+
else
5944+
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D.getBeginLoc());
59415945
};
59425946

59435947
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,

clang/lib/CodeGen/CGStmtOpenMP.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,102 @@ static llvm::Function *emitOutlinedFunctionPrologue(
613613
return F;
614614
}
615615

616+
static llvm::Function *emitOutlinedFunctionPrologueAggregate(
617+
CodeGenFunction &CGF, FunctionArgList &Args,
618+
llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>>
619+
&LocalAddrs,
620+
llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>>
621+
&VLASizes,
622+
llvm::Value *&CXXThisValue, const CapturedStmt &CS, SourceLocation Loc,
623+
StringRef FunctionName) {
624+
const CapturedDecl *CD = CS.getCapturedDecl();
625+
const RecordDecl *RD = CS.getCapturedRecordDecl();
626+
627+
CXXThisValue = nullptr;
628+
// Build the argument list.
629+
CodeGenModule &CGM = CGF.CGM;
630+
ASTContext &Ctx = CGM.getContext();
631+
Args.append(CD->param_begin(), CD->param_end());
632+
633+
// Create the function declaration.
634+
const CGFunctionInfo &FuncInfo =
635+
CGM.getTypes().arrangeBuiltinFunctionDeclaration(Ctx.VoidTy, Args);
636+
llvm::FunctionType *FuncLLVMTy = CGM.getTypes().GetFunctionType(FuncInfo);
637+
638+
auto *F =
639+
llvm::Function::Create(FuncLLVMTy, llvm::GlobalValue::InternalLinkage,
640+
FunctionName, &CGM.getModule());
641+
CGM.SetInternalFunctionAttributes(CD, F, FuncInfo);
642+
if (CD->isNothrow())
643+
F->setDoesNotThrow();
644+
F->setDoesNotRecurse();
645+
646+
// Generate the function.
647+
CGF.StartFunction(CD, Ctx.VoidTy, F, FuncInfo, Args, Loc, Loc);
648+
Address ContextAddr = CGF.GetAddrOfLocalVar(CD->getContextParam());
649+
llvm::Value *ContextV = CGF.Builder.CreateLoad(ContextAddr);
650+
LValue ContextLV = CGF.MakeNaturalAlignAddrLValue(
651+
ContextV, CGM.getContext().getTagDeclType(RD));
652+
auto I = CS.captures().begin();
653+
for (const FieldDecl *FD : RD->fields()) {
654+
LValue FieldLV = CGF.EmitLValueForFieldInitialization(ContextLV, FD);
655+
// Do not map arguments if we emit function with non-original types.
656+
Address LocalAddr = FieldLV.getAddress(CGF);
657+
// If we are capturing a pointer by copy we don't need to do anything, just
658+
// use the value that we get from the arguments.
659+
if (I->capturesVariableByCopy() && FD->getType()->isAnyPointerType()) {
660+
const VarDecl *CurVD = I->getCapturedVar();
661+
LocalAddrs.insert({FD, {CurVD, LocalAddr}});
662+
++I;
663+
continue;
664+
}
665+
666+
LValue ArgLVal =
667+
CGF.MakeAddrLValue(LocalAddr, FD->getType(), AlignmentSource::Decl);
668+
if (FD->hasCapturedVLAType()) {
669+
llvm::Value *ExprArg = CGF.EmitLoadOfScalar(ArgLVal, I->getLocation());
670+
const VariableArrayType *VAT = FD->getCapturedVLAType();
671+
VLASizes.try_emplace(FD, VAT->getSizeExpr(), ExprArg);
672+
} else if (I->capturesVariable()) {
673+
const VarDecl *Var = I->getCapturedVar();
674+
QualType VarTy = Var->getType();
675+
Address ArgAddr = ArgLVal.getAddress(CGF);
676+
if (ArgLVal.getType()->isLValueReferenceType()) {
677+
ArgAddr = CGF.EmitLoadOfReference(ArgLVal);
678+
} else if (!VarTy->isVariablyModifiedType() || !VarTy->isPointerType()) {
679+
assert(ArgLVal.getType()->isPointerType());
680+
ArgAddr = CGF.EmitLoadOfPointer(
681+
ArgAddr, ArgLVal.getType()->castAs<PointerType>());
682+
}
683+
LocalAddrs.insert(
684+
{FD,
685+
{Var, Address(ArgAddr.getBasePointer(), ArgAddr.getElementType(),
686+
Ctx.getDeclAlign(Var))}});
687+
} else if (I->capturesVariableByCopy()) {
688+
assert(!FD->getType()->isAnyPointerType() &&
689+
"Not expecting a captured pointer.");
690+
const VarDecl *Var = I->getCapturedVar();
691+
Address CopyAddr = CGF.CreateMemTemp(FD->getType(), Ctx.getDeclAlign(FD),
692+
Var->getName());
693+
LValue CopyLVal =
694+
CGF.MakeAddrLValue(CopyAddr, FD->getType(), AlignmentSource::Decl);
695+
696+
RValue ArgRVal = CGF.EmitLoadOfLValue(ArgLVal, I->getLocation());
697+
CGF.EmitStoreThroughLValue(ArgRVal, CopyLVal);
698+
699+
LocalAddrs.insert({FD, {Var, CopyAddr}});
700+
} else {
701+
// If 'this' is captured, load it into CXXThisValue.
702+
assert(I->capturesThis());
703+
CXXThisValue = CGF.EmitLoadOfScalar(ArgLVal, I->getLocation());
704+
LocalAddrs.insert({FD, {nullptr, ArgLVal.getAddress(CGF)}});
705+
}
706+
++I;
707+
}
708+
709+
return F;
710+
}
711+
616712
llvm::Function *
617713
CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
618714
SourceLocation Loc) {
@@ -695,6 +791,36 @@ CodeGenFunction::GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
695791
return WrapperF;
696792
}
697793

794+
llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunctionAggregate(
795+
const CapturedStmt &S, SourceLocation Loc) {
796+
assert(
797+
CapturedStmtInfo &&
798+
"CapturedStmtInfo should be set when generating the captured function");
799+
const CapturedDecl *CD = S.getCapturedDecl();
800+
// Build the argument list.
801+
FunctionArgList Args;
802+
llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
803+
llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
804+
StringRef FunctionName = CapturedStmtInfo->getHelperName();
805+
llvm::Function *F = emitOutlinedFunctionPrologueAggregate(
806+
*this, Args, LocalAddrs, VLASizes, CXXThisValue, S, Loc, FunctionName);
807+
CodeGenFunction::OMPPrivateScope LocalScope(*this);
808+
for (const auto &LocalAddrPair : LocalAddrs) {
809+
if (LocalAddrPair.second.first) {
810+
LocalScope.addPrivate(LocalAddrPair.second.first,
811+
LocalAddrPair.second.second);
812+
}
813+
}
814+
(void)LocalScope.Privatize();
815+
for (const auto &VLASizePair : VLASizes)
816+
VLASizeMap[VLASizePair.second.first] = VLASizePair.second.second;
817+
PGO.assignRegionCounters(GlobalDecl(CD), F);
818+
CapturedStmtInfo->EmitBody(*this, CD->getBody());
819+
(void)LocalScope.ForceCleanup();
820+
FinishFunction(CD->getBodyRBrace());
821+
return F;
822+
}
823+
698824
//===----------------------------------------------------------------------===//
699825
// OpenMP Directive Emission
700826
//===----------------------------------------------------------------------===//

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3642,6 +3642,9 @@ class CodeGenFunction : public CodeGenTypeCache {
36423642
Address GenerateCapturedStmtArgument(const CapturedStmt &S);
36433643
llvm::Function *GenerateOpenMPCapturedStmtFunction(const CapturedStmt &S,
36443644
SourceLocation Loc);
3645+
llvm::Function *
3646+
GenerateOpenMPCapturedStmtFunctionAggregate(const CapturedStmt &S,
3647+
SourceLocation Loc);
36453648
void GenerateOpenMPCapturedVars(const CapturedStmt &S,
36463649
SmallVectorImpl<llvm::Value *> &CapturedVars);
36473650
void emitOMPSimpleStore(LValue LVal, RValue RVal, QualType RValTy,

clang/test/OpenMP/declare_target_codegen.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ int bar() { return 1 + foo() + bar() + baz1() + baz2(); }
150150
int maini1() {
151151
int a;
152152
static long aa = 32 + bbb + ccc + fff + ggg;
153-
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}maini1{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr noundef nonnull align {{[0-9]+}} dereferenceable({{[0-9]+}}) %{{.*}}, i64 {{.*}}, i64 {{.*}})
153+
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}maini1{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr {{.*}})
154154
#pragma omp target map(tofrom \
155155
: a, b)
156156
{
@@ -163,7 +163,7 @@ int maini1() {
163163

164164
int baz3() { return 2 + baz2(); }
165165
int baz2() {
166-
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz2{{.*}}_l[[@LINE+1]](ptr {{.*}}, i64 {{.*}})
166+
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz2{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr {{.*}})
167167
#pragma omp target parallel
168168
++c;
169169
return 2 + baz3();
@@ -175,7 +175,7 @@ static __typeof(create) __t_create __attribute__((__weakref__("__create")));
175175

176176
int baz5() {
177177
bool a;
178-
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz5{{.*}}_l[[@LINE+1]](ptr {{.*}}, i64 {{.*}})
178+
// CHECK-DAG: define weak_odr protected void @__omp_offloading_{{.*}}baz5{{.*}}_l[[@LINE+1]](ptr {{.*}}, ptr {{.*}})
179179
#pragma omp target
180180
a = __extension__(void *) & __t_create != 0;
181181
return a;

clang/test/OpenMP/declare_target_link_codegen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int maini1() {
5252
return 0;
5353
}
5454

55-
// DEVICE: define weak_odr protected void @__omp_offloading_{{.*}}_{{.*}}maini1{{.*}}_l44(ptr {{[^,]+}}, ptr noundef nonnull align {{[0-9]+}} dereferenceable{{[^,]*}}
55+
// DEVICE: define weak_odr protected void @__omp_offloading_{{.*}}_{{.*}}maini1{{.*}}_l44(ptr {{[^,]+}}, ptr {{[^,]*}}
5656
// DEVICE: [[C_REF:%.+]] = load ptr, ptr @c_decl_tgt_ref_ptr,
5757
// DEVICE: [[C:%.+]] = load i32, ptr [[C_REF]],
5858
// DEVICE: store i32 [[C]], ptr %

0 commit comments

Comments
 (0)