Skip to content

Commit 681d59f

Browse files
author
Mogball
committed
[mlir][llvm] Improve error message when translating llvm.call_intrinsic (RELAND)
This is more user-friendly over an opaque crash. Reland after fixing bad rebase.
1 parent 141c4e7 commit 681d59f

File tree

5 files changed

+155
-70
lines changed

5 files changed

+155
-70
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,31 +1080,6 @@ def LLVM_vector_extract
10801080
}];
10811081
}
10821082

1083-
//===--------------------------------------------------------------------===//
1084-
// CallIntrinsicOp
1085-
//===--------------------------------------------------------------------===//
1086-
1087-
def LLVM_CallIntrinsicOp
1088-
: LLVM_Op<"call_intrinsic",
1089-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1090-
let summary = "Call to an LLVM intrinsic function.";
1091-
let description = [{
1092-
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
1093-
the MLIR function type of this op to determine which intrinsic to call.
1094-
}];
1095-
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
1096-
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1097-
"{}">:$fastmathFlags);
1098-
let results = (outs Variadic<LLVM_Type>:$results);
1099-
let llvmBuilder = [{
1100-
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
1101-
}];
1102-
let assemblyFormat = [{
1103-
$intrin `(` $args `)` `:` functional-type($args, $results)
1104-
custom<LLVMOpAttrs>(attr-dict)
1105-
}];
1106-
}
1107-
11081083
//
11091084
// LLVM Vector Predication operations.
11101085
//

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,4 +1759,30 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", []> {
17591759
}];
17601760
}
17611761

1762+
//===--------------------------------------------------------------------===//
1763+
// CallIntrinsicOp
1764+
//===--------------------------------------------------------------------===//
1765+
1766+
def LLVM_CallIntrinsicOp
1767+
: LLVM_Op<"call_intrinsic",
1768+
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1769+
let summary = "Call to an LLVM intrinsic function.";
1770+
let description = [{
1771+
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
1772+
the MLIR function type of this op to determine which intrinsic to call.
1773+
}];
1774+
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
1775+
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1776+
"{}">:$fastmathFlags);
1777+
let results = (outs Optional<LLVM_Type>:$results);
1778+
let llvmBuilder = [{
1779+
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
1780+
}];
1781+
let assemblyFormat = [{
1782+
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
1783+
}];
1784+
1785+
let hasVerifier = 1;
1786+
}
1787+
17621788
#endif // LLVMIR_OPS

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// MLIR, and the LLVM IR dialect. It also registers the dialect.
1111
//
1212
//===----------------------------------------------------------------------===//
13+
1314
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1415
#include "LLVMInlining.h"
1516
#include "TypeDetail.h"
@@ -2785,6 +2786,16 @@ OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
27852786
return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
27862787
}
27872788

2789+
//===----------------------------------------------------------------------===//
2790+
// CallIntrinsicOp
2791+
//===----------------------------------------------------------------------===//
2792+
2793+
LogicalResult CallIntrinsicOp::verify() {
2794+
if (!getIntrin().startswith("llvm."))
2795+
return emitOpError() << "intrinsic name must start with 'llvm.'";
2796+
return success();
2797+
}
2798+
27882799
//===----------------------------------------------------------------------===//
27892800
// OpAsmDialectInterface
27902801
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,19 @@ static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) {
5858
return position;
5959
}
6060

61+
/// Convert an LLVM type to a string for printing in diagnostics.
62+
static std::string diagStr(const llvm::Type *type) {
63+
std::string str;
64+
llvm::raw_string_ostream os(str);
65+
type->print(os);
66+
return os.str();
67+
}
68+
6169
/// Get the declaration of an overloaded llvm intrinsic. First we get the
6270
/// overloaded argument types and/or result type from the CallIntrinsicOp, and
6371
/// then use those to get the correct declaration of the overloaded intrinsic.
6472
static FailureOr<llvm::Function *>
65-
getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id,
73+
getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
6674
llvm::Module *module,
6775
LLVM::ModuleTranslation &moduleTranslation) {
6876
SmallVector<llvm::Type *, 8> allArgTys;
@@ -86,7 +94,9 @@ getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id,
8694
if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
8795
overloadedArgTys) !=
8896
llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
89-
return op.emitOpError("intrinsic type is not a match");
97+
return mlir::emitError(op.getLoc(), "call intrinsic signature ")
98+
<< diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr()
99+
<< " does not match any of the overloads";
90100
}
91101

92102
ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
@@ -101,8 +111,8 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
101111
llvm::Intrinsic::ID id =
102112
llvm::Function::lookupIntrinsicID(op.getIntrinAttr());
103113
if (!id)
104-
return op.emitOpError()
105-
<< "couldn't find intrinsic: " << op.getIntrinAttr();
114+
return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ")
115+
<< op.getIntrinAttr();
106116

107117
llvm::Function *fn = nullptr;
108118
if (llvm::Intrinsic::isOverloaded(id)) {
@@ -114,6 +124,44 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
114124
} else {
115125
fn = llvm::Intrinsic::getDeclaration(module, id, {});
116126
}
127+
128+
// Check the result type of the call.
129+
const llvm::Type *intrinType =
130+
op.getNumResults() == 0
131+
? llvm::Type::getVoidTy(module->getContext())
132+
: moduleTranslation.convertType(op.getResultTypes().front());
133+
if (intrinType != fn->getReturnType()) {
134+
return mlir::emitError(op.getLoc(), "intrinsic call returns ")
135+
<< diagStr(intrinType) << " but " << op.getIntrinAttr()
136+
<< " actually returns " << diagStr(fn->getReturnType());
137+
}
138+
139+
// Check the argument types of the call. If the function is variadic, check
140+
// the subrange of required arguments.
141+
if (!fn->getFunctionType()->isVarArg() &&
142+
op.getNumOperands() != fn->arg_size()) {
143+
return mlir::emitError(op.getLoc(), "intrinsic call has ")
144+
<< op.getNumOperands() << " operands but " << op.getIntrinAttr()
145+
<< " expects " << fn->arg_size();
146+
}
147+
if (fn->getFunctionType()->isVarArg() &&
148+
op.getNumOperands() < fn->arg_size()) {
149+
return mlir::emitError(op.getLoc(), "intrinsic call has ")
150+
<< op.getNumOperands() << " operands but variadic "
151+
<< op.getIntrinAttr() << " expects at least " << fn->arg_size();
152+
}
153+
// Check the arguments up to the number the function requires.
154+
for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
155+
const llvm::Type *expected = fn->getArg(i)->getType();
156+
const llvm::Type *actual =
157+
moduleTranslation.convertType(op.getOperandTypes()[i]);
158+
if (actual != expected) {
159+
return mlir::emitError(op.getLoc(), "intrinsic call operand #")
160+
<< i << " has type " << diagStr(actual) << " but "
161+
<< op.getIntrinAttr() << " expects " << diagStr(expected);
162+
}
163+
}
164+
117165
FastmathFlagsInterface itf = op;
118166
builder.setFastMathFlags(getFastmathFlags(itf));
119167

Lines changed: 66 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,107 @@
11
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
22

3-
// CHECK: ; ModuleID = 'LLVMDialectModule'
4-
// CHECK: source_filename = "LLVMDialectModule"
5-
// CHECK: declare ptr @malloc(i64)
6-
// CHECK: declare void @free(ptr)
73
// CHECK: define <4 x float> @round_sse41() {
8-
// CHECK: %1 = call reassoc <4 x float> @llvm.x86.sse41.round.ss(<4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, <4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, i32 1)
4+
// CHECK: %1 = call reassoc <4 x float> @llvm.x86.sse41.round.ss(<4 x float> {{.*}}, <4 x float> {{.*}}, i32 1)
95
// CHECK: ret <4 x float> %1
106
// CHECK: }
117
llvm.func @round_sse41() -> vector<4xf32> {
12-
%0 = llvm.mlir.constant(1 : i32) : i32
13-
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
14-
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
15-
llvm.return %res: vector<4xf32>
8+
%0 = llvm.mlir.constant(1 : i32) : i32
9+
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
10+
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
11+
llvm.return %res: vector<4xf32>
1612
}
1713

1814
// -----
1915

20-
// CHECK: ; ModuleID = 'LLVMDialectModule'
21-
// CHECK: source_filename = "LLVMDialectModule"
22-
23-
// CHECK: declare ptr @malloc(i64)
24-
25-
// CHECK: declare void @free(ptr)
26-
2716
// CHECK: define float @round_overloaded() {
2817
// CHECK: %1 = call float @llvm.round.f32(float 1.000000e+00)
2918
// CHECK: ret float %1
3019
// CHECK: }
3120
llvm.func @round_overloaded() -> f32 {
32-
%0 = llvm.mlir.constant(1.0 : f32) : f32
33-
%res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
34-
llvm.return %res: f32
21+
%0 = llvm.mlir.constant(1.0 : f32) : f32
22+
%res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
23+
llvm.return %res: f32
3524
}
3625

3726
// -----
3827

39-
// CHECK: ; ModuleID = 'LLVMDialectModule'
40-
// CHECK: source_filename = "LLVMDialectModule"
41-
// CHECK: declare ptr @malloc(i64)
42-
// CHECK: declare void @free(ptr)
4328
// CHECK: define void @lifetime_start() {
4429
// CHECK: %1 = alloca float, i8 1, align 4
4530
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1)
4631
// CHECK: ret void
4732
// CHECK: }
4833
llvm.func @lifetime_start() {
49-
%0 = llvm.mlir.constant(4 : i64) : i64
50-
%1 = llvm.mlir.constant(1 : i8) : i8
51-
%2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
52-
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
53-
llvm.return
34+
%0 = llvm.mlir.constant(4 : i64) : i64
35+
%1 = llvm.mlir.constant(1 : i8) : i8
36+
%2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
37+
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
38+
llvm.return
5439
}
5540

5641
// -----
5742

43+
// CHECK-LABEL: define void @variadic()
5844
llvm.func @variadic() {
59-
%0 = llvm.mlir.constant(1 : i8) : i8
60-
%1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
61-
llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> ()
62-
llvm.return
45+
%0 = llvm.mlir.constant(1 : i8) : i8
46+
%1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
47+
// CHECK: call void (...) @llvm.localescape(ptr %1, ptr %1)
48+
llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> ()
49+
llvm.return
6350
}
6451

6552
// -----
6653

6754
llvm.func @no_intrinsic() {
68-
// expected-error@below {{'llvm.call_intrinsic' op couldn't find intrinsic: "llvm.does_not_exist"}}
69-
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
70-
llvm.call_intrinsic "llvm.does_not_exist"() : () -> ()
71-
llvm.return
55+
// expected-error@below {{could not find LLVM intrinsic: "llvm.does_not_exist"}}
56+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
57+
llvm.call_intrinsic "llvm.does_not_exist"() : () -> ()
58+
llvm.return
7259
}
7360

7461
// -----
7562

7663
llvm.func @bad_types() {
77-
%0 = llvm.mlir.constant(1 : i8) : i8
78-
// expected-error@below {{'llvm.call_intrinsic' op intrinsic type is not a match}}
79-
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
80-
llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
81-
llvm.return
64+
%0 = llvm.mlir.constant(1 : i8) : i8
65+
// expected-error@below {{call intrinsic signature i8 (i8) to overloaded intrinsic "llvm.round" does not match any of the overloads}}
66+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
67+
llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
68+
llvm.return
69+
}
70+
71+
// -----
72+
73+
llvm.func @bad_result() {
74+
// expected-error @below {{intrinsic call returns void but "llvm.x86.sse41.round.ss" actually returns <4 x float>}}
75+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
76+
llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> ()
77+
llvm.return
78+
}
79+
80+
// -----
81+
82+
llvm.func @bad_result() {
83+
// expected-error @below {{intrinsic call returns <8 x float> but "llvm.x86.sse41.round.ss" actually returns <4 x float>}}
84+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
85+
llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> (vector<8xf32>)
86+
llvm.return
87+
}
88+
89+
// -----
90+
91+
llvm.func @bad_args() {
92+
// expected-error @below {{intrinsic call has 0 operands but "llvm.x86.sse41.round.ss" expects 3}}
93+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
94+
llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> (vector<4xf32>)
95+
llvm.return
96+
}
97+
98+
// -----
99+
100+
llvm.func @bad_args() {
101+
%0 = llvm.mlir.constant(1 : i64) : i64
102+
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
103+
// expected-error @below {{intrinsic call operand #2 has type i64 but "llvm.x86.sse41.round.ss" expects i32}}
104+
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
105+
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32> {fastmathFlags = #llvm.fastmath<reassoc>}
106+
llvm.return
82107
}

0 commit comments

Comments
 (0)