Skip to content

Commit 05494f3

Browse files
[MLIR][LLVM] Tail call support for inline asm op (#140826)
1 parent a2ce564 commit 05494f3

File tree

11 files changed

+56
-11
lines changed

11 files changed

+56
-11
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -758,9 +758,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
758758
the LLVM function type that uses an explicit void type to model functions
759759
that do not return a value.
760760

761-
If this operatin has the `no_inline` attribute, then this specific function call
762-
will never be inlined. The opposite behavior will occur if the call has `always_inline`
763-
attribute. The `inline_hint` attribute indicates that it is desirable to inline
761+
If this operatin has the `no_inline` attribute, then this specific function call
762+
will never be inlined. The opposite behavior will occur if the call has `always_inline`
763+
attribute. The `inline_hint` attribute indicates that it is desirable to inline
764764
this function call.
765765

766766
Examples:
@@ -2298,13 +2298,17 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
22982298
written, or referenced.
22992299
Attempting to define or reference any symbol or any global behavior is
23002300
considered undefined behavior at this time.
2301+
If `tail_call_kind` is used, the operation behaves like the specified
2302+
tail call kind. The `musttail` kind it's not available for this operation,
2303+
since it isn't supported by LLVM's inline asm.
23012304
}];
23022305
let arguments = (
23032306
ins Variadic<LLVM_Type>:$operands,
23042307
StrAttr:$asm_string,
23052308
StrAttr:$constraints,
23062309
UnitAttr:$has_side_effects,
23072310
UnitAttr:$is_align_stack,
2311+
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$tail_call_kind,
23082312
OptionalAttr<
23092313
DefaultValuedAttr<AsmATTOrIntel, "AsmDialect::AD_ATT">>:$asm_dialect,
23102314
OptionalAttr<ArrayAttr>:$operand_attrs);
@@ -2314,6 +2318,7 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
23142318
let assemblyFormat = [{
23152319
(`has_side_effects` $has_side_effects^)?
23162320
(`is_align_stack` $is_align_stack^)?
2321+
(`tail_call_kind` `=` $tail_call_kind^)?
23172322
(`asm_dialect` `=` $asm_dialect^)?
23182323
(`operand_attrs` `=` $operand_attrs^)?
23192324
attr-dict
@@ -2326,6 +2331,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
23262331
return "elementtype";
23272332
}
23282333
}];
2334+
2335+
let hasVerifier = 1;
23292336
}
23302337

23312338
//===--------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
439439
op,
440440
/*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
441441
/*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
442-
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
442+
/*is_align_stack=*/false, LLVM::TailCallKind::None,
443+
/*asm_dialect=*/asmDialectAttr,
443444
/*operand_attrs=*/ArrayAttr());
444445
return success();
445446
}

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
571571
/*asm_string=*/asmStr,
572572
/*constraints=*/constraintStr,
573573
/*has_side_effects=*/true,
574-
/*is_align_stack=*/false,
574+
/*is_align_stack=*/false, LLVM::TailCallKind::None,
575575
/*asm_dialect=*/asmDialectAttr,
576576
/*operand_attrs=*/ArrayAttr());
577577
}

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
139139
/*asm_string=*/llvm::StringRef(ptxInstruction),
140140
/*constraints=*/registerConstraints.data(),
141141
/*has_side_effects=*/interfaceOp.hasSideEffect(),
142-
/*is_align_stack=*/false,
142+
/*is_align_stack=*/false, LLVM::TailCallKind::None,
143143
/*asm_dialect=*/asmDialectAttr,
144144
/*operand_attrs=*/ArrayAttr());
145145
}

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4042,6 +4042,21 @@ LogicalResult LLVM::masked_scatter::verify() {
40424042
return success();
40434043
}
40444044

4045+
//===----------------------------------------------------------------------===//
4046+
// InlineAsmOp
4047+
//===----------------------------------------------------------------------===//
4048+
4049+
LogicalResult InlineAsmOp::verify() {
4050+
if (!getTailCallKindAttr())
4051+
return success();
4052+
4053+
if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4054+
return emitOpError(
4055+
"tail call kind 'musttail' is not supported by this operation");
4056+
4057+
return success();
4058+
}
4059+
40454060
//===----------------------------------------------------------------------===//
40464061
// LLVMDialect initialization, type parsing, and registration.
40474062
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
4141
auto asmOp = b.create<LLVM::InlineAsmOp>(
4242
v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
4343
/*constraints=*/asmCstr, /*has_side_effects=*/false,
44-
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
44+
/*is_align_stack=*/false, LLVM::TailCallKind::None,
45+
/*asm_dialect=*/asmDialectAttr,
4546
/*operand_attrs=*/ArrayAttr());
4647
return asmOp.getResult(0);
4748
}

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/ADT/TypeSwitch.h"
2020
#include "llvm/IR/IRBuilder.h"
2121
#include "llvm/IR/InlineAsm.h"
22+
#include "llvm/IR/Instructions.h"
2223
#include "llvm/IR/MDBuilder.h"
2324
#include "llvm/IR/MatrixBuilder.h"
2425
#include "llvm/IR/Operator.h"
@@ -507,6 +508,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
507508
llvm::CallInst *inst = builder.CreateCall(
508509
inlineAsmInst,
509510
moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
511+
inst->setTailCallKind(convertTailCallKindToLLVM(
512+
inlineAsmOp.getTailCallKindAttr().getTailCallKind()));
510513
if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
511514
llvm::AttributeList attrList;
512515
for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,6 +2200,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
22002200
builder.getStringAttr(asmI->getAsmString()),
22012201
builder.getStringAttr(asmI->getConstraintString()),
22022202
asmI->hasSideEffects(), asmI->isAlignStack(),
2203+
convertTailCallKindFromLLVM(callInst->getTailCallKind()),
22032204
AsmDialectAttr::get(
22042205
mlirModule.getContext(),
22052206
convertAsmDialectFromLLVM(asmI->getDialect())),

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,3 +1882,11 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
18821882
%0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64>
18831883
llvm.return %0 : !llvm.array<2 x f64>
18841884
}
1885+
1886+
// ----
1887+
1888+
llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
1889+
// expected-error@+1 {{op tail call kind 'musttail' is not supported}}
1890+
%8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
1891+
llvm.return
1892+
}

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,11 @@ define i32 @inlineasm(i32 %arg1) {
554554
define void @inlineasm2() {
555555
%p = alloca ptr, align 8
556556
; CHECK: {{.*}} = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
557-
; CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
558-
call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
557+
; CHECK-NEXT: llvm.inline_asm has_side_effects tail_call_kind = <tail> asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
558+
tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
559+
560+
; CHECK: llvm.inline_asm has_side_effects tail_call_kind = <notail> asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
561+
notail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
559562
ret void
560563
}
561564

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,8 +2081,14 @@ llvm.func @useInlineAsm(%arg0: i32, %arg1 : !llvm.ptr) {
20812081
// CHECK-NEXT: call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
20822082
%5 = llvm.inline_asm "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
20832083

2084-
// CHECK-NEXT: call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %1)
2085-
%6 = llvm.inline_asm has_side_effects operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" %arg1 : (!llvm.ptr) -> !llvm.void
2084+
// CHECK-NEXT: tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %1)
2085+
%6 = llvm.inline_asm has_side_effects tail_call_kind = <tail> operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" %arg1 : (!llvm.ptr) -> !llvm.void
2086+
2087+
// CHECK-NEXT: = call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
2088+
%7 = llvm.inline_asm tail_call_kind = <none> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
2089+
2090+
// CHECK-NEXT: notail call { i8, i8 } asm "foo", "=r,=r,r"(i32 {{.*}})
2091+
%8 = llvm.inline_asm tail_call_kind = <notail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
20862092

20872093
llvm.return
20882094
}

0 commit comments

Comments
 (0)