Skip to content

Commit 4a01079

Browse files
authored
Expose Tail Kind Call to MLIR (#98080)
I would like to mark a call op in LLVM dialect as Musttail. The calling convention attribute only exposes Tail, not Musttail. I noticed that the CallInst of LLVM has an additional field to specify the flavor of tail call kind. I bubbled this up to the LLVM dialect by adding another attribute that maps to LLVM::CallInst::TailCallKind.
1 parent 8d5ba75 commit 4a01079

File tree

10 files changed

+174
-7
lines changed

10 files changed

+174
-7
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,4 +1077,13 @@ def LLVM_PoisonAttr : LLVM_Attr<"Poison", "poison">;
10771077
/// Folded into from LLVM::ZeroOp.
10781078
def LLVM_ZeroAttr : LLVM_Attr<"Zero", "zero">;
10791079

1080+
//===----------------------------------------------------------------------===//
1081+
// TailCallKindAttr
1082+
//===----------------------------------------------------------------------===//
1083+
1084+
def TailCallKindAttr : LLVM_Attr<"TailCallKind", "tailcallkind"> {
1085+
let parameters = (ins "TailCallKind":$TailCallKind);
1086+
let assemblyFormat = "`<` $TailCallKind `>`";
1087+
}
1088+
10801089
#endif // LLVMIR_ATTRDEFS

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class TBAANodeAttr : public Attribute {
8989
// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
9090
// --gen-attr-* and --gen-attrdef-*.
9191
using cconv::CConv;
92+
using tailcallkind::TailCallKind;
9293
using linkage::Linkage;
9394
} // namespace LLVM
9495
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,35 @@ def CConv : DialectAttr<
279279
"::mlir::LLVM::CConvAttr::get($_builder.getContext(), $0)";
280280
}
281281

282+
//===----------------------------------------------------------------------===//
283+
// TailCallKind
284+
//===----------------------------------------------------------------------===//
285+
286+
def TailCallKindNone : LLVM_EnumAttrCase<"None", "none", "TCK_None", 0>;
287+
def TailCallKindTail : LLVM_EnumAttrCase<"Tail", "tail", "TCK_Tail", 1>;
288+
def TailCallKindMustTail : LLVM_EnumAttrCase<"MustTail", "musttail", "TCK_MustTail", 2>;
289+
def TailCallKindNoTailCall : LLVM_EnumAttrCase<"NoTail", "notail", "TCK_NoTail", 3>;
290+
291+
def TailCallKindEnum : LLVM_EnumAttr<
292+
"TailCallKind",
293+
"::llvm::CallInst::TailCallKind",
294+
"Tail Call Kind",
295+
[TailCallKindNone, TailCallKindNoTailCall,
296+
TailCallKindMustTail, TailCallKindTail]> {
297+
let cppNamespace = "::mlir::LLVM::tailcallkind";
298+
}
299+
300+
def TailCallKind : DialectAttr<
301+
LLVM_Dialect,
302+
CPred<"::llvm::isa<::mlir::LLVM::TailCallKindAttr>($_self)">,
303+
"LLVM Calling Convention specification"> {
304+
let storageType = "::mlir::LLVM::TailCallKindAttr";
305+
let returnType = "::mlir::LLVM::tailcallkind::TailCallKind";
306+
let convertFromStorage = "$_self.getTailCallKind()";
307+
let constBuilderCall =
308+
"::mlir::LLVM::TailCallKindAttr::get($_builder.getContext(), $0)";
309+
}
310+
282311
//===----------------------------------------------------------------------===//
283312
// DIEmissionKind
284313
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
650650
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
651651
"{}">:$fastmathFlags,
652652
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
653-
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
653+
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
654+
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
654655
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
655656
let arguments = !con(args, aliasAttrs);
656657
let results = (outs Optional<LLVM_Type>:$result);

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ using namespace mlir;
4444
using namespace mlir::LLVM;
4545
using mlir::LLVM::cconv::getMaxEnumValForCConv;
4646
using mlir::LLVM::linkage::getMaxEnumValForLinkage;
47+
using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
4748

4849
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
4950

@@ -197,6 +198,7 @@ struct EnumTraits {};
197198
REGISTER_ENUM_TYPE(Linkage);
198199
REGISTER_ENUM_TYPE(UnnamedAddr);
199200
REGISTER_ENUM_TYPE(CConv);
201+
REGISTER_ENUM_TYPE(TailCallKind);
200202
REGISTER_ENUM_TYPE(Visibility);
201203
} // namespace
202204

@@ -974,7 +976,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974976
build(builder, state, results,
975977
TypeAttr::get(getLLVMFuncType(builder.getContext(), results, args)),
976978
callee, args, /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
977-
/*CConv=*/nullptr,
979+
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
978980
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
979981
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
980982
}
@@ -997,7 +999,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
997999
build(builder, state, getCallOpResultTypes(calleeType),
9981000
TypeAttr::get(calleeType), callee, args, /*fastmathFlags=*/nullptr,
9991001
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
1000-
/*access_groups=*/nullptr,
1002+
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
10011003
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10021004
}
10031005

@@ -1006,7 +1008,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10061008
build(builder, state, getCallOpResultTypes(calleeType),
10071009
TypeAttr::get(calleeType), /*callee=*/nullptr, args,
10081010
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1009-
/*CConv=*/nullptr,
1011+
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10101012
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
10111013
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10121014
}
@@ -1017,7 +1019,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
10171019
build(builder, state, getCallOpResultTypes(calleeType),
10181020
TypeAttr::get(calleeType), SymbolRefAttr::get(func), args,
10191021
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1020-
/*CConv=*/nullptr,
1022+
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
10211023
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
10221024
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10231025
}
@@ -1180,6 +1182,9 @@ void CallOp::print(OpAsmPrinter &p) {
11801182
if (getCConv() != LLVM::CConv::C)
11811183
p << stringifyCConv(getCConv()) << ' ';
11821184

1185+
if(getTailCallKind() != LLVM::TailCallKind::None)
1186+
p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
1187+
11831188
// Print the direct callee if present as a function attribute, or an indirect
11841189
// callee (first operand) otherwise.
11851190
if (isDirect)
@@ -1194,7 +1199,8 @@ void CallOp::print(OpAsmPrinter &p) {
11941199
p << " vararg(" << calleeType << ")";
11951200

11961201
p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1197-
{getCConvAttrName(), "callee", "callee_type"});
1202+
{getCConvAttrName(), "callee", "callee_type",
1203+
getTailCallKindAttrName()});
11981204

11991205
p << " : ";
12001206
if (!isDirect)
@@ -1262,7 +1268,7 @@ static ParseResult parseOptionalCallFuncPtr(
12621268
return success();
12631269
}
12641270

1265-
// <operation> ::= `llvm.call` (cconv)? (function-id | ssa-use)
1271+
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
12661272
// `(` ssa-use-list `)`
12671273
// ( `vararg(` var-arg-func-type `)` )?
12681274
// attribute-dict? `:` (type `,`)? function-type
@@ -1277,6 +1283,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
12771283
CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
12781284
parser, result, LLVM::CConv::C)));
12791285

1286+
result.addAttribute(
1287+
getTailCallKindAttrName(result.name),
1288+
TailCallKindAttr::get(parser.getContext(),
1289+
parseOptionalLLVMKeyword<TailCallKind>(
1290+
parser, result, LLVM::TailCallKind::None)));
1291+
12801292
// Parse a function pointer for indirect calls.
12811293
if (parseOptionalCallFuncPtr(parser, operands))
12821294
return failure();

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
218218
operandsRef.drop_front());
219219
}
220220
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
221+
call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
221222
moduleTranslation.setAccessGroupsMetadata(callOp, call);
222223
moduleTranslation.setAliasScopeMetadata(callOp, call);
223224
moduleTranslation.setTBAAMetadata(callOp, call);

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,6 +1468,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
14681468
callOp = builder.create<CallOp>(loc, funcTy, operands);
14691469
}
14701470
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1471+
callOp.setTailCallKind(
1472+
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
14711473
setFastmathFlagsAttr(inst, callOp);
14721474
if (!callInst->getType()->isVoidTy())
14731475
mapValue(inst, callOp.getResult());

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,3 +673,41 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) {
673673
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
674674
llvm.return
675675
}
676+
677+
// CHECK: llvm.func @tail_call_target() -> i32
678+
llvm.func @tail_call_target() -> i32
679+
680+
// CHECK-LABEL: @test_none
681+
llvm.func @test_none() -> i32 {
682+
// CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
683+
%0 = llvm.call none @tail_call_target() : () -> i32
684+
llvm.return %0 : i32
685+
}
686+
687+
// CHECK-LABEL: @test_default
688+
llvm.func @test_default() -> i32 {
689+
// CHECK-NEXT: llvm.call @tail_call_target() : () -> i32
690+
%0 = llvm.call @tail_call_target() : () -> i32
691+
llvm.return %0 : i32
692+
}
693+
694+
// CHECK-LABEL: @test_musttail
695+
llvm.func @test_musttail() -> i32 {
696+
// CHECK-NEXT: llvm.call musttail @tail_call_target() : () -> i32
697+
%0 = llvm.call musttail @tail_call_target() : () -> i32
698+
llvm.return %0 : i32
699+
}
700+
701+
// CHECK-LABEL: @test_tail
702+
llvm.func @test_tail() -> i32 {
703+
// CHECK-NEXT: llvm.call tail @tail_call_target() : () -> i32
704+
%0 = llvm.call tail @tail_call_target() : () -> i32
705+
llvm.return %0 : i32
706+
}
707+
708+
// CHECK-LABEL: @test_notail
709+
llvm.func @test_notail() -> i32 {
710+
// CHECK-NEXT: llvm.call notail @tail_call_target() : () -> i32
711+
%0 = llvm.call notail @tail_call_target() : () -> i32
712+
llvm.return %0 : i32
713+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
// CHECK: declare i32 @foo()
4+
llvm.func @foo() -> i32
5+
6+
// CHECK-LABEL: @test_none
7+
llvm.func @test_none() -> i32 {
8+
// CHECK-NEXT: call i32 @foo()
9+
%0 = llvm.call none @foo() : () -> i32
10+
llvm.return %0 : i32
11+
}
12+
13+
// CHECK-LABEL: @test_default
14+
llvm.func @test_default() -> i32 {
15+
// CHECK-NEXT: call i32 @foo()
16+
%0 = llvm.call @foo() : () -> i32
17+
llvm.return %0 : i32
18+
}
19+
20+
// CHECK-LABEL: @test_musttail
21+
llvm.func @test_musttail() -> i32 {
22+
// CHECK-NEXT: musttail call i32 @foo()
23+
%0 = llvm.call musttail @foo() : () -> i32
24+
llvm.return %0 : i32
25+
}
26+
27+
// CHECK-LABEL: @test_tail
28+
llvm.func @test_tail() -> i32 {
29+
// CHECK-NEXT: tail call i32 @foo()
30+
%0 = llvm.call tail @foo() : () -> i32
31+
llvm.return %0 : i32
32+
}
33+
34+
// CHECK-LABEL: @test_notail
35+
llvm.func @test_notail() -> i32 {
36+
// CHECK-NEXT: notail call i32 @foo()
37+
%0 = llvm.call notail @foo() : () -> i32
38+
llvm.return %0 : i32
39+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK: llvm.func @tailkind()
4+
declare void @tailkind()
5+
6+
; CHECK-LABEL: @call_tailkind
7+
define void @call_tailkind() {
8+
; CHECK: llvm.call musttail @tailkind()
9+
musttail call void @tailkind()
10+
ret void
11+
}
12+
13+
; // -----
14+
15+
; CHECK: llvm.func @tailkind()
16+
declare void @tailkind()
17+
18+
; CHECK-LABEL: @call_tailkind
19+
define void @call_tailkind() {
20+
; CHECK: llvm.call tail @tailkind()
21+
tail call void @tailkind()
22+
ret void
23+
}
24+
25+
; // -----
26+
27+
; CHECK: llvm.func @tailkind()
28+
declare void @tailkind()
29+
30+
; CHECK-LABEL: @call_tailkind
31+
define void @call_tailkind() {
32+
; CHECK: llvm.call notail @tailkind()
33+
notail call void @tailkind()
34+
ret void
35+
}

0 commit comments

Comments
 (0)