Skip to content

Commit 65fedb4

Browse files
authored
[MLIR] Add support for calling conventions to LLVM::CallOp and LLVM::InvokeOp (#71319)
Despite the fact that the LLVM dialect’s `FuncOp` already supports calling conventions, there was yet no support for them in the ops that actually perform function calls, which led to incorrect LLVM IR being generated if one actually tried setting a `FuncOp`’s calling convention to anything other than `ccc`. This commit adds support for calling conventions to `LLVM::CallOp` and `LLVM::InvokeOp` and makes sure that calling conventions are parsed, printed, and lowered appropriately.
1 parent e160130 commit 65fedb4

File tree

6 files changed

+181
-61
lines changed

6 files changed

+181
-61
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
521521
Variadic<LLVM_Type>:$callee_operands,
522522
Variadic<LLVM_Type>:$normalDestOperands,
523523
Variadic<LLVM_Type>:$unwindDestOperands,
524-
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
524+
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
525+
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
525526
let results = (outs Variadic<LLVM_Type>);
526527
let successors = (successor AnySuccessor:$normalDest,
527528
AnySuccessor:$unwindDest);
@@ -602,7 +603,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
602603
Variadic<LLVM_Type>:$callee_operands,
603604
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
604605
"{}">:$fastmathFlags,
605-
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
606+
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
607+
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
606608
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
607609
let arguments = !con(args, aliasAttrs);
608610
let results = (outs Optional<LLVM_Type>:$result);

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 83 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,52 @@ static Type getI1SameShape(Type type) {
9797
return i1Type;
9898
}
9999

100+
// Parses one of the keywords provided in the list `keywords` and returns the
101+
// position of the parsed keyword in the list. If none of the keywords from the
102+
// list is parsed, returns -1.
103+
static int parseOptionalKeywordAlternative(OpAsmParser &parser,
104+
ArrayRef<StringRef> keywords) {
105+
for (const auto &en : llvm::enumerate(keywords)) {
106+
if (succeeded(parser.parseOptionalKeyword(en.value())))
107+
return en.index();
108+
}
109+
return -1;
110+
}
111+
112+
namespace {
113+
template <typename Ty>
114+
struct EnumTraits {};
115+
116+
#define REGISTER_ENUM_TYPE(Ty) \
117+
template <> \
118+
struct EnumTraits<Ty> { \
119+
static StringRef stringify(Ty value) { return stringify##Ty(value); } \
120+
static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
121+
}
122+
123+
REGISTER_ENUM_TYPE(Linkage);
124+
REGISTER_ENUM_TYPE(UnnamedAddr);
125+
REGISTER_ENUM_TYPE(CConv);
126+
REGISTER_ENUM_TYPE(Visibility);
127+
} // namespace
128+
129+
/// Parse an enum from the keyword, or default to the provided default value.
130+
/// The return type is the enum type by default, unless overridden with the
131+
/// second template argument.
132+
template <typename EnumTy, typename RetTy = EnumTy>
133+
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
134+
OperationState &result,
135+
EnumTy defaultValue) {
136+
SmallVector<StringRef, 10> names;
137+
for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
138+
names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
139+
140+
int index = parseOptionalKeywordAlternative(parser, names);
141+
if (index == -1)
142+
return static_cast<RetTy>(defaultValue);
143+
return static_cast<RetTy>(index);
144+
}
145+
100146
//===----------------------------------------------------------------------===//
101147
// Printing, parsing, folding and builder for LLVM::CmpOp.
102148
//===----------------------------------------------------------------------===//
@@ -859,6 +905,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
859905
build(builder, state, results,
860906
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
861907
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
908+
/*CConv=*/nullptr,
862909
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
863910
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
864911
}
@@ -880,7 +927,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
880927
ValueRange args) {
881928
build(builder, state, getCallOpResultTypes(calleeType),
882929
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
883-
/*branch_weights=*/nullptr, /*access_groups=*/nullptr,
930+
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
931+
/*access_groups=*/nullptr,
884932
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
885933
}
886934

@@ -889,6 +937,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
889937
build(builder, state, getCallOpResultTypes(calleeType),
890938
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
891939
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
940+
/*CConv=*/nullptr,
892941
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
893942
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
894943
}
@@ -899,9 +948,11 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
899948
build(builder, state, getCallOpResultTypes(calleeType),
900949
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
901950
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
951+
/*CConv=*/nullptr,
902952
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
903953
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
904954
}
955+
905956
CallInterfaceCallable CallOp::getCallableForCallee() {
906957
// Direct call.
907958
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
@@ -1054,9 +1105,14 @@ void CallOp::print(OpAsmPrinter &p) {
10541105
isVarArg = calleeType.isVarArg();
10551106
}
10561107

1108+
p << ' ';
1109+
1110+
// Print calling convention.
1111+
if (getCConv() != LLVM::CConv::C)
1112+
p << stringifyCConv(getCConv()) << ' ';
1113+
10571114
// Print the direct callee if present as a function attribute, or an indirect
10581115
// callee (first operand) otherwise.
1059-
p << ' ';
10601116
if (isDirect)
10611117
p.printSymbolName(callee.value());
10621118
else
@@ -1069,7 +1125,7 @@ void CallOp::print(OpAsmPrinter &p) {
10691125
p << " vararg(" << calleeType << ")";
10701126

10711127
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1072-
{"callee", "callee_type"});
1128+
{getCConvAttrName(), "callee", "callee_type"});
10731129

10741130
p << " : ";
10751131
if (!isDirect)
@@ -1137,7 +1193,7 @@ static ParseResult parseOptionalCallFuncPtr(
11371193
return success();
11381194
}
11391195

1140-
// <operation> ::= `llvm.call` (function-id | ssa-use)
1196+
// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
11411197
// `(` ssa-use-list `)`
11421198
// ( `vararg(` var-arg-func-type `)` )?
11431199
// attribute-dict? `:` (type `,`)? function-type
@@ -1146,6 +1202,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
11461202
TypeAttr calleeType;
11471203
SmallVector<OpAsmParser::UnresolvedOperand> operands;
11481204

1205+
// Default to C Calling Convention if no keyword is provided.
1206+
result.addAttribute(
1207+
getCConvAttrName(result.name),
1208+
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1209+
parser, result, LLVM::CConv::C)));
1210+
11491211
// Parse a function pointer for indirect calls.
11501212
if (parseOptionalCallFuncPtr(parser, operands))
11511213
return failure();
@@ -1191,7 +1253,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
11911253
auto calleeType = func.getFunctionType();
11921254
build(builder, state, getCallOpResultTypes(calleeType),
11931255
TypeAttr::get(calleeType), SymbolRefAttr::get(func), ops, normalOps,
1194-
unwindOps, nullptr, normal, unwind);
1256+
unwindOps, nullptr, nullptr, normal, unwind);
11951257
}
11961258

11971259
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
@@ -1200,7 +1262,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
12001262
ValueRange unwindOps) {
12011263
build(builder, state, tys,
12021264
TypeAttr::get(getLLVMFuncType(builder.getContext(), tys, ops)), callee,
1203-
ops, normalOps, unwindOps, nullptr, normal, unwind);
1265+
ops, normalOps, unwindOps, nullptr, nullptr, normal, unwind);
12041266
}
12051267

12061268
void InvokeOp::build(OpBuilder &builder, OperationState &state,
@@ -1209,7 +1271,7 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state,
12091271
Block *unwind, ValueRange unwindOps) {
12101272
build(builder, state, getCallOpResultTypes(calleeType),
12111273
TypeAttr::get(calleeType), callee, ops, normalOps, unwindOps, nullptr,
1212-
normal, unwind);
1274+
nullptr, normal, unwind);
12131275
}
12141276

12151277
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
@@ -1275,6 +1337,10 @@ void InvokeOp::print(OpAsmPrinter &p) {
12751337

12761338
p << ' ';
12771339

1340+
// Print calling convention.
1341+
if (getCConv() != LLVM::CConv::C)
1342+
p << stringifyCConv(getCConv()) << ' ';
1343+
12781344
// Either function name or pointer
12791345
if (isDirect)
12801346
p.printSymbolName(callee.value());
@@ -1290,9 +1356,9 @@ void InvokeOp::print(OpAsmPrinter &p) {
12901356
if (isVarArg)
12911357
p << " vararg(" << calleeType << ")";
12921358

1293-
p.printOptionalAttrDict(
1294-
(*this)->getAttrs(),
1295-
{InvokeOp::getOperandSegmentSizeAttr(), "callee", "callee_type"});
1359+
p.printOptionalAttrDict((*this)->getAttrs(),
1360+
{InvokeOp::getOperandSegmentSizeAttr(), "callee",
1361+
"callee_type", InvokeOp::getCConvAttrName()});
12961362

12971363
p << " : ";
12981364
if (!isDirect)
@@ -1301,7 +1367,7 @@ void InvokeOp::print(OpAsmPrinter &p) {
13011367
getResultTypes());
13021368
}
13031369

1304-
// <operation> ::= `llvm.invoke` (function-id | ssa-use)
1370+
// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
13051371
// `(` ssa-use-list `)`
13061372
// `to` bb-id (`[` ssa-use-and-type-list `]`)?
13071373
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
@@ -1315,6 +1381,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
13151381
SmallVector<Value, 4> normalOperands, unwindOperands;
13161382
Builder &builder = parser.getBuilder();
13171383

1384+
// Default to C Calling Convention if no keyword is provided.
1385+
result.addAttribute(
1386+
getCConvAttrName(result.name),
1387+
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1388+
parser, result, LLVM::CConv::C)));
1389+
13181390
// Parse a function pointer for indirect calls.
13191391
if (parseOptionalCallFuncPtr(parser, operands))
13201392
return failure();
@@ -1788,52 +1860,6 @@ void GlobalOp::print(OpAsmPrinter &p) {
17881860
}
17891861
}
17901862

1791-
// Parses one of the keywords provided in the list `keywords` and returns the
1792-
// position of the parsed keyword in the list. If none of the keywords from the
1793-
// list is parsed, returns -1.
1794-
static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1795-
ArrayRef<StringRef> keywords) {
1796-
for (const auto &en : llvm::enumerate(keywords)) {
1797-
if (succeeded(parser.parseOptionalKeyword(en.value())))
1798-
return en.index();
1799-
}
1800-
return -1;
1801-
}
1802-
1803-
namespace {
1804-
template <typename Ty>
1805-
struct EnumTraits {};
1806-
1807-
#define REGISTER_ENUM_TYPE(Ty) \
1808-
template <> \
1809-
struct EnumTraits<Ty> { \
1810-
static StringRef stringify(Ty value) { return stringify##Ty(value); } \
1811-
static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
1812-
}
1813-
1814-
REGISTER_ENUM_TYPE(Linkage);
1815-
REGISTER_ENUM_TYPE(UnnamedAddr);
1816-
REGISTER_ENUM_TYPE(CConv);
1817-
REGISTER_ENUM_TYPE(Visibility);
1818-
} // namespace
1819-
1820-
/// Parse an enum from the keyword, or default to the provided default value.
1821-
/// The return type is the enum type by default, unless overriden with the
1822-
/// second template argument.
1823-
template <typename EnumTy, typename RetTy = EnumTy>
1824-
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
1825-
OperationState &result,
1826-
EnumTy defaultValue) {
1827-
SmallVector<StringRef, 10> names;
1828-
for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
1829-
names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1830-
1831-
int index = parseOptionalKeywordAlternative(parser, names);
1832-
if (index == -1)
1833-
return static_cast<RetTy>(defaultValue);
1834-
return static_cast<RetTy>(index);
1835-
}
1836-
18371863
static LogicalResult verifyComdat(Operation *op,
18381864
std::optional<SymbolRefAttr> attr) {
18391865
if (!attr)

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
200200
call = builder.CreateCall(calleeType, operandsRef.front(),
201201
operandsRef.drop_front());
202202
}
203+
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
203204
moduleTranslation.setAccessGroupsMetadata(callOp, call);
204205
moduleTranslation.setAliasScopeMetadata(callOp, call);
205206
moduleTranslation.setTBAAMetadata(callOp, call);
@@ -275,7 +276,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
275276
if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
276277
auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
277278
ArrayRef<llvm::Value *> operandsRef(operands);
278-
llvm::Instruction *result;
279+
llvm::InvokeInst *result;
279280
if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
280281
result = builder.CreateInvoke(
281282
moduleTranslation.lookupFunction(attr.getValue()),
@@ -290,6 +291,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
290291
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
291292
operandsRef.drop_front());
292293
}
294+
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
293295
moduleTranslation.mapBranch(invOp, result);
294296
// InvokeOp can only have 0 or 1 result
295297
if (invOp->getNumResults() != 0) {
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
llvm.func @__gxx_personality_v0(...) -> i32
4+
5+
// CHECK: declare fastcc void @cconv_fastcc()
6+
// CHECK: declare void @cconv_ccc()
7+
// CHECK: declare tailcc void @cconv_tailcc()
8+
// CHECK: declare ghccc void @cconv_ghccc()
9+
llvm.func fastcc @cconv_fastcc()
10+
llvm.func ccc @cconv_ccc()
11+
llvm.func tailcc @cconv_tailcc()
12+
llvm.func cc_10 @cconv_ghccc()
13+
14+
// CHECK-LABEL: @test_ccs
15+
llvm.func @test_ccs() {
16+
// CHECK-NEXT: call fastcc void @cconv_fastcc()
17+
// CHECK-NEXT: call void @cconv_ccc()
18+
// CHECK-NEXT: call void @cconv_ccc()
19+
// CHECK-NEXT: call tailcc void @cconv_tailcc()
20+
// CHECK-NEXT: call ghccc void @cconv_ghccc()
21+
// CHECK-NEXT: ret void
22+
llvm.call fastcc @cconv_fastcc() : () -> ()
23+
llvm.call ccc @cconv_ccc() : () -> ()
24+
llvm.call @cconv_ccc() : () -> ()
25+
llvm.call tailcc @cconv_tailcc() : () -> ()
26+
llvm.call cc_10 @cconv_ghccc() : () -> ()
27+
llvm.return
28+
}
29+
30+
// CHECK-LABEL: @test_ccs_invoke
31+
llvm.func @test_ccs_invoke() attributes { personality = @__gxx_personality_v0 } {
32+
// CHECK-NEXT: invoke fastcc void @cconv_fastcc()
33+
// CHECK-NEXT: to label %[[normal1:[0-9]+]] unwind label %[[unwind:[0-9]+]]
34+
llvm.invoke fastcc @cconv_fastcc() to ^bb1 unwind ^bb6 : () -> ()
35+
36+
^bb1:
37+
// CHECK: [[normal1]]:
38+
// CHECK-NEXT: invoke void @cconv_ccc()
39+
// CHECK-NEXT: to label %[[normal2:[0-9]+]] unwind label %[[unwind:[0-9]+]]
40+
llvm.invoke ccc @cconv_ccc() to ^bb2 unwind ^bb6 : () -> ()
41+
42+
^bb2:
43+
// CHECK: [[normal2]]:
44+
// CHECK-NEXT: invoke void @cconv_ccc()
45+
// CHECK-NEXT: to label %[[normal3:[0-9]+]] unwind label %[[unwind:[0-9]+]]
46+
llvm.invoke @cconv_ccc() to ^bb3 unwind ^bb6 : () -> ()
47+
48+
^bb3:
49+
// CHECK: [[normal3]]:
50+
// CHECK-NEXT: invoke tailcc void @cconv_tailcc()
51+
// CHECK-NEXT: to label %[[normal4:[0-9]+]] unwind label %[[unwind:[0-9]+]]
52+
llvm.invoke tailcc @cconv_tailcc() to ^bb4 unwind ^bb6 : () -> ()
53+
54+
^bb4:
55+
// CHECK: [[normal4]]:
56+
// CHECK-NEXT: invoke ghccc void @cconv_ghccc()
57+
// CHECK-NEXT: to label %[[normal5:[0-9]+]] unwind label %[[unwind:[0-9]+]]
58+
llvm.invoke cc_10 @cconv_ghccc() to ^bb5 unwind ^bb6 : () -> ()
59+
60+
^bb5:
61+
// CHECK: [[normal5]]:
62+
// CHECK-NEXT: ret void
63+
llvm.return
64+
65+
// CHECK: [[unwind]]:
66+
// CHECK-NEXT: landingpad { ptr, i32 }
67+
// CHECK-NEXT: cleanup
68+
// CHECK-NEXT: ret void
69+
^bb6:
70+
%0 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
71+
llvm.return
72+
}

0 commit comments

Comments
 (0)