Skip to content

Commit 62eaebe

Browse files
authored
[SYCL-MLIR]: constructAttributeList returns CallingConv (#7339)
This PR modifies `constructAttributeList` to pass CallingConv as a reference argument (same as clang). The PR also applies a few clang-tidy related changes. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 3489e38 commit 62eaebe

File tree

6 files changed

+88
-99
lines changed

6 files changed

+88
-99
lines changed

polygeist/tools/cgeist/Lib/CGCall.cc

Lines changed: 52 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,21 @@ static void castCallerArgs(mlir::func::FuncOp Callee,
100100
/******************************************************************************/
101101

102102
ValueCategory MLIRScanner::callHelper(
103-
mlir::func::FuncOp Tocall, QualType ObjType,
103+
mlir::func::FuncOp ToCall, QualType ObjType,
104104
ArrayRef<std::pair<ValueCategory, clang::Expr *>> Arguments,
105105
QualType RetType, bool RetReference, clang::Expr *Expr,
106-
const FunctionDecl *Callee) {
106+
const FunctionDecl &Callee) {
107107
SmallVector<mlir::Value, 4> Args;
108-
auto FnType = Tocall.getFunctionType();
109-
const clang::CodeGen::CGFunctionInfo &FI =
110-
Glob.GetOrCreateCGFunctionInfo(Callee);
111-
auto FIArgs = FI.arguments();
108+
mlir::FunctionType FnType = ToCall.getFunctionType();
109+
const clang::CodeGen::CGFunctionInfo &CalleeInfo =
110+
Glob.GetOrCreateCGFunctionInfo(&Callee);
111+
auto CalleeArgs = CalleeInfo.arguments();
112112

113113
size_t I = 0;
114114
// map from declaration name to mlir::value
115115
std::map<std::string, mlir::Value> MapFuncOperands;
116116

117-
for (auto Pair : Arguments) {
117+
for (const std::pair<ValueCategory, clang::Expr *> &Pair : Arguments) {
118118
ValueCategory Arg = std::get<0>(Pair);
119119
clang::Expr *A = std::get<1>(Pair);
120120

@@ -134,10 +134,10 @@ ValueCategory MLIRScanner::callHelper(
134134
if (I >= FnType.getInputs().size() || (I != 0 && A == nullptr)) {
135135
LLVM_DEBUG({
136136
Expr->dump();
137-
Tocall.dump();
137+
ToCall.dump();
138138
FnType.dump();
139-
for (auto a : Arguments)
140-
std::get<1>(a)->dump();
139+
for (auto A : Arguments)
140+
std::get<1>(A)->dump();
141141
});
142142
assert(false && "too many arguments in calls");
143143
}
@@ -175,21 +175,21 @@ ValueCategory MLIRScanner::callHelper(
175175
});
176176
assert(Arg.isReference);
177177

178-
auto Mt =
178+
auto MT =
179179
Glob.getTypes()
180180
.getMLIRType(
181181
Glob.getCGM().getContext().getLValueReferenceType(AType))
182182
.cast<MemRefType>();
183183

184184
LLVM_DEBUG({
185-
llvm::dbgs() << "mt: " << Mt << "\n";
185+
llvm::dbgs() << "MT: " << MT << "\n";
186186
llvm::dbgs() << "getLValueReferenceType(aType): "
187187
<< Glob.getCGM().getContext().getLValueReferenceType(
188188
AType)
189189
<< "\n";
190190
});
191191

192-
auto Shape = std::vector<int64_t>(Mt.getShape());
192+
auto Shape = std::vector<int64_t>(MT.getShape());
193193
assert(Shape.size() == 2);
194194

195195
auto Pshape = Shape[0];
@@ -199,21 +199,22 @@ ValueCategory MLIRScanner::callHelper(
199199
OpBuilder Abuilder(builder.getContext());
200200
Abuilder.setInsertionPointToStart(allocationScope);
201201
auto Alloc = Abuilder.create<mlir::memref::AllocaOp>(
202-
loc, mlir::MemRefType::get(Shape, Mt.getElementType(),
202+
loc, mlir::MemRefType::get(Shape, MT.getElementType(),
203203
MemRefLayoutAttrInterface(),
204-
Mt.getMemorySpace()));
204+
MT.getMemorySpace()));
205205
ValueCategory(Alloc, /*isRef*/ true)
206206
.store(builder, Arg, /*isArray*/ IsArray);
207207
Shape[0] = Pshape;
208208
Val = builder.create<mlir::memref::CastOp>(
209209
loc,
210-
mlir::MemRefType::get(Shape, Mt.getElementType(),
210+
mlir::MemRefType::get(Shape, MT.getElementType(),
211211
MemRefLayoutAttrInterface(),
212-
Mt.getMemorySpace()),
212+
MT.getMemorySpace()),
213213
Alloc);
214214
} else {
215-
if (FIArgs[I].info.getKind() == clang::CodeGen::ABIArgInfo::Indirect ||
216-
FIArgs[I].info.getKind() ==
215+
if (CalleeArgs[I].info.getKind() ==
216+
clang::CodeGen::ABIArgInfo::Indirect ||
217+
CalleeArgs[I].info.getKind() ==
217218
clang::CodeGen::ABIArgInfo::IndirectAliased) {
218219
OpBuilder Abuilder(builder.getContext());
219220
Abuilder.setInsertionPointToStart(allocationScope);
@@ -226,20 +227,20 @@ ValueCategory MLIRScanner::callHelper(
226227
Val = Abuilder.create<mlir::memref::CastOp>(
227228
loc, mlir::MemRefType::get(-1, Arg.getValue(builder).getType()),
228229
Val);
229-
} else {
230+
} else
230231
Val = Abuilder.create<mlir::LLVM::AllocaOp>(
231232
loc, Ty, Abuilder.create<arith::ConstantIntOp>(loc, 1, 64), 0);
232-
}
233+
233234
ValueCategory(Val, /*isRef*/ true)
234235
.store(builder, Arg.getValue(builder));
235236
} else
236237
Val = Arg.getValue(builder);
237238

238239
if (Val.getType().isa<LLVM::LLVMPointerType>() &&
239-
ExpectedType.isa<MemRefType>()) {
240+
ExpectedType.isa<MemRefType>())
240241
Val = builder.create<polygeist::Pointer2MemrefOp>(loc, ExpectedType,
241242
Val);
242-
}
243+
243244
if (auto PrevTy = Val.getType().dyn_cast<mlir::IntegerType>()) {
244245
auto IPostTy = ExpectedType.cast<mlir::IntegerType>();
245246
if (PrevTy != IPostTy)
@@ -266,7 +267,7 @@ ValueCategory MLIRScanner::callHelper(
266267
}
267268

268269
// handle lowerto pragma.
269-
if (LTInfo.SymbolTable.count(Tocall.getName())) {
270+
if (LTInfo.SymbolTable.count(ToCall.getName())) {
270271
SmallVector<mlir::Value> InputOperands;
271272
SmallVector<mlir::Value> OutputOperands;
272273
for (StringRef Input : LTInfo.InputSymbol)
@@ -280,7 +281,7 @@ ValueCategory MLIRScanner::callHelper(
280281
InputOperands.append(Args);
281282

282283
return ValueCategory(mlirclang::replaceFuncByOperation(
283-
Tocall, LTInfo.SymbolTable[Tocall.getName()],
284+
ToCall, LTInfo.SymbolTable[ToCall.getName()],
284285
builder, InputOperands, OutputOperands)
285286
->getResult(0),
286287
/*isReference=*/false);
@@ -292,13 +293,13 @@ ValueCategory MLIRScanner::callHelper(
292293

293294
mlir::Value Alloc;
294295
if (IsArrayReturn) {
295-
auto Mt =
296+
auto MT =
296297
Glob.getTypes()
297298
.getMLIRType(
298299
Glob.getCGM().getContext().getLValueReferenceType(RetType))
299300
.cast<MemRefType>();
300301

301-
auto Shape = std::vector<int64_t>(Mt.getShape());
302+
auto Shape = std::vector<int64_t>(MT.getShape());
302303
assert(Shape.size() == 2);
303304

304305
auto Pshape = Shape[0];
@@ -308,14 +309,14 @@ ValueCategory MLIRScanner::callHelper(
308309
OpBuilder Abuilder(builder.getContext());
309310
Abuilder.setInsertionPointToStart(allocationScope);
310311
Alloc = Abuilder.create<mlir::memref::AllocaOp>(
311-
loc, mlir::MemRefType::get(Shape, Mt.getElementType(),
312+
loc, mlir::MemRefType::get(Shape, MT.getElementType(),
312313
MemRefLayoutAttrInterface(),
313-
Mt.getMemorySpace()));
314+
MT.getMemorySpace()));
314315
Shape[0] = Pshape;
315316
Alloc = builder.create<mlir::memref::CastOp>(
316317
loc,
317-
mlir::MemRefType::get(Shape, Mt.getElementType(),
318-
MemRefLayoutAttrInterface(), Mt.getMemorySpace()),
318+
mlir::MemRefType::get(Shape, MT.getElementType(),
319+
MemRefLayoutAttrInterface(), MT.getMemorySpace()),
319320
Alloc);
320321
Args.push_back(Alloc);
321322
}
@@ -391,19 +392,19 @@ ValueCategory MLIRScanner::callHelper(
391392
auto Oldpoint = builder.getInsertionPoint();
392393
auto *Oldblock = builder.getInsertionBlock();
393394
builder.setInsertionPointToStart(&Op.getRegion().front());
394-
builder.create<CallOp>(loc, Tocall, Args);
395+
builder.create<CallOp>(loc, ToCall, Args);
395396
builder.create<gpu::TerminatorOp>(loc);
396397
builder.setInsertionPoint(Oldblock, Oldpoint);
397398
return nullptr;
398399
}
399400

400401
// Try to rescue some mismatched types.
401-
castCallerArgs(Tocall, Args, builder);
402+
castCallerArgs(ToCall, Args, builder);
402403

403404
/// Try to emit SYCL operations before creating a CallOp
404405
mlir::Operation *Op = emitSYCLOps(Expr, Args);
405406
if (!Op)
406-
Op = builder.create<CallOp>(loc, Tocall, Args);
407+
Op = builder.create<CallOp>(loc, ToCall, Args);
407408

408409
if (IsArrayReturn) {
409410
// TODO remedy return
@@ -412,13 +413,12 @@ ValueCategory MLIRScanner::callHelper(
412413
assert(!RetReference);
413414
return ValueCategory(Alloc, /*isReference*/ true);
414415
}
415-
if (Op->getNumResults()) {
416+
417+
if (Op->getNumResults())
416418
return ValueCategory(Op->getResult(0),
417419
/*isReference*/ RetReference);
418-
}
420+
419421
return nullptr;
420-
llvm::errs() << "do not support indirecto call of " << Tocall << "\n";
421-
assert(0 && "no indirect");
422422
}
423423

424424
std::pair<ValueCategory, bool>
@@ -445,21 +445,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *Expr) {
445445
});
446446

447447
auto Loc = getMLIRLocation(Expr->getExprLoc());
448-
/*
449-
if (auto ic = dyn_cast<ImplicitCastExpr>(expr->getCallee()))
450-
if (auto sr = dyn_cast<DeclRefExpr>(ic->getSubExpr())) {
451-
if (sr->getDecl()->getIdentifier() &&
452-
sr->getDecl()->getName() == "__shfl_up_sync") {
453-
std::vector<mlir::Value> args;
454-
for (auto a : expr->arguments()) {
455-
args.push_back(Visit(a).getValue(builder));
456-
}
457-
builder.create<gpu::ShuffleOp>(loc, );
458-
assert(0 && "__shfl_up_sync unhandled");
459-
return nullptr;
460-
}
461-
}
462-
*/
463448

464449
auto ValEmitted = emitGPUCallExpr(Expr);
465450
if (ValEmitted.second)
@@ -498,7 +483,6 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *Expr) {
498483

499484
if (auto *Oc = dyn_cast<CXXMemberCallExpr>(Expr)) {
500485
if (auto *LHS = dyn_cast<CXXTypeidExpr>(Oc->getImplicitObjectArgument())) {
501-
Expr->getCallee()->dump();
502486
if (auto *Ic = dyn_cast<MemberExpr>(Expr->getCallee()))
503487
if (auto *Sr = dyn_cast<NamedDecl>(Ic->getMemberDecl())) {
504488
if (Sr->getIdentifier() && Sr->getName() == "name") {
@@ -946,8 +930,8 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *Expr) {
946930
(isa<CXXOperatorCallExpr>(Expr) &&
947931
cast<CXXOperatorCallExpr>(Expr)->getOperator() ==
948932
OO_GreaterGreater)) {
949-
const auto *Tocall = EmitCallee(Expr->getCallee());
950-
auto StrcmpF = Glob.GetOrCreateLLVMFunction(Tocall);
933+
const auto *ToCall = EmitCallee(Expr->getCallee());
934+
auto StrcmpF = Glob.GetOrCreateLLVMFunction(ToCall);
951935

952936
std::vector<mlir::Value> Args;
953937
std::vector<std::pair<mlir::Value, mlir::Value>> Ops;
@@ -1149,17 +1133,18 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *Expr) {
11491133
if (auto *CC = dyn_cast<CXXMemberCallExpr>(Expr)) {
11501134
ValueCategory Obj = Visit(CC->getImplicitObjectArgument());
11511135
ObjType = CC->getObjectType();
1152-
#ifdef DEBUG
1153-
if (!obj.val) {
1154-
function.dump();
1155-
llvm::errs() << " objval: " << obj.val << "\n";
1156-
expr->dump();
1157-
CC->getImplicitObjectArgument()->dump();
1158-
}
1159-
#endif
1160-
if (cast<MemberExpr>(CC->getCallee()->IgnoreParens())->isArrow()) {
1136+
LLVM_DEBUG({
1137+
if (!Obj.val) {
1138+
function.dump();
1139+
llvm::errs() << " objval: " << Obj.val << "\n";
1140+
Expr->dump();
1141+
CC->getImplicitObjectArgument()->dump();
1142+
}
1143+
});
1144+
1145+
if (cast<MemberExpr>(CC->getCallee()->IgnoreParens())->isArrow())
11611146
Obj = Obj.dereference(builder);
1162-
}
1147+
11631148
assert(Obj.val);
11641149
assert(Obj.isReference);
11651150
Args.emplace_back(std::make_pair(Obj, (clang::Expr *)nullptr));
@@ -1169,7 +1154,7 @@ ValueCategory MLIRScanner::VisitCallExpr(clang::CallExpr *Expr) {
11691154
Args.push_back(std::make_pair(Visit(A), A));
11701155

11711156
return callHelper(ToCall, ObjType, Args, Expr->getType(),
1172-
Expr->isLValue() || Expr->isXValue(), Expr, Callee);
1157+
Expr->isLValue() || Expr->isXValue(), Expr, *Callee);
11731158
}
11741159

11751160
std::pair<ValueCategory, bool>

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -862,19 +862,20 @@ ValueCategory MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons,
862862
ShouldEmit = true;
863863

864864
FunctionToEmit F(*ctorDecl, mlirclang::getInputContext(builder));
865-
auto tocall = cast<func::FuncOp>(Glob.GetOrCreateMLIRFunction(F, ShouldEmit));
865+
auto ToCall = cast<func::FuncOp>(Glob.GetOrCreateMLIRFunction(F, ShouldEmit));
866866

867-
SmallVector<std::pair<ValueCategory, clang::Expr *>> args;
868-
args.emplace_back(std::make_pair(obj, (clang::Expr *)nullptr));
869-
for (auto a : cons->arguments())
870-
args.push_back(std::make_pair(Visit(a), a));
871-
callHelper(tocall, innerType, args,
867+
SmallVector<std::pair<ValueCategory, clang::Expr *>> Args{{obj, nullptr}};
868+
Args.reserve(cons->getNumArgs() + 1);
869+
for (auto A : cons->arguments())
870+
Args.emplace_back(Visit(A), A);
871+
872+
callHelper(ToCall, innerType, Args,
872873
/*retType*/ Glob.getCGM().getContext().VoidTy, false, cons,
873-
ctorDecl);
874+
*ctorDecl);
874875

875-
if (Glob.getCGM().getContext().getAsArrayType(cons->getType())) {
876+
if (Glob.getCGM().getContext().getAsArrayType(cons->getType()))
876877
builder.setInsertionPoint(oldblock, oldpoint);
877-
}
878+
878879
return endobj;
879880
}
880881

polygeist/tools/cgeist/Lib/CodeGenTypes.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,16 +646,12 @@ CodeGenTypes::getFunctionType(const clang::CodeGen::CGFunctionInfo &FI,
646646
void CodeGenTypes::constructAttributeList(
647647
StringRef Name, const clang::CodeGen::CGFunctionInfo &FI,
648648
clang::CodeGen::CGCalleeInfo CalleeInfo, mlirclang::AttributeList &AttrList,
649-
bool AttrOnCallSite, bool IsThunk) {
649+
unsigned &CallingConv, bool AttrOnCallSite, bool IsThunk) {
650650
MLIRContext *Ctx = TheModule->getContext();
651651
mlirclang::AttrBuilder FuncAttrsBuilder(*Ctx);
652652
mlirclang::AttrBuilder RetAttrsBuilder(*Ctx);
653653

654-
unsigned CC = FI.getEffectiveCallingConvention();
655-
FuncAttrsBuilder.addAttribute(
656-
"llvm.cconv", mlir::LLVM::CConvAttr::get(
657-
Ctx, static_cast<mlir::LLVM::cconv::CConv>(CC)));
658-
654+
CallingConv = FI.getEffectiveCallingConvention();
659655
if (FI.isNoReturn())
660656
FuncAttrsBuilder.addPassThroughAttribute(llvm::Attribute::NoReturn);
661657
if (FI.isCmseNSCall())

polygeist/tools/cgeist/Lib/CodeGenTypes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ class CodeGenTypes {
7979
const clang::CodeGen::CGFunctionInfo &FI,
8080
clang::CodeGen::CGCalleeInfo CalleeInfo,
8181
mlirclang::AttributeList &AttrList,
82-
bool AttrOnCallSite, bool IsThunk);
82+
unsigned &CallingConv, bool AttrOnCallSite,
83+
bool IsThunk);
8384

8485
// TODO: Possibly create a SYCLTypeCache
8586
mlir::Type getMLIRType(clang::QualType QT, bool *ImplicitRef = nullptr,

0 commit comments

Comments
 (0)