Skip to content

Commit f05173d

Browse files
committed
Implement callee/caller type checking for llvm.call
This aligns the behavior with the standard call as well as the LLVM verifier. Reviewed By: ftynse, dcaballe Differential Revision: https://reviews.llvm.org/D88362
1 parent 22664a3 commit f05173d

File tree

3 files changed

+145
-6
lines changed

3 files changed

+145
-6
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -417,11 +417,7 @@ def LLVM_CallOp : LLVM_Op<"call">,
417417
$_state.addAttributes(attributes);
418418
$_state.addOperands(operands);
419419
}]>];
420-
let verifier = [{
421-
if (getNumResults() > 1)
422-
return emitOpError("must have 0 or 1 result");
423-
return success();
424-
}];
420+
let verifier = [{ return ::verify(*this); }];
425421
let parser = [{ return parseCallOp(parser, result); }];
426422
let printer = [{ printCallOp(p, *this); }];
427423
}

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,9 +531,83 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser,
531531
}
532532

533533
//===----------------------------------------------------------------------===//
534-
// Printing/parsing for LLVM::CallOp.
534+
// Verifying/Printing/parsing for LLVM::CallOp.
535535
//===----------------------------------------------------------------------===//
536536

537+
static LogicalResult verify(CallOp &op) {
538+
if (op.getNumResults() > 1)
539+
return op.emitOpError("must have 0 or 1 result");
540+
541+
// Type for the callee, we'll get it differently depending if it is a direct
542+
// or indirect call.
543+
LLVMType fnType;
544+
545+
bool isIndirect = false;
546+
547+
// If this is an indirect call, the callee attribute is missing.
548+
Optional<StringRef> calleeName = op.callee();
549+
if (!calleeName) {
550+
isIndirect = true;
551+
if (!op.getNumOperands())
552+
return op.emitOpError(
553+
"must have either a `callee` attribute or at least an operand");
554+
fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
555+
if (!fnType)
556+
return op.emitOpError("indirect call to a non-llvm type: ")
557+
<< op.getOperand(0).getType();
558+
auto ptrType = fnType.dyn_cast<LLVMPointerType>();
559+
if (!ptrType)
560+
return op.emitOpError("indirect call expects a pointer as callee: ")
561+
<< fnType;
562+
fnType = ptrType.getElementType();
563+
} else {
564+
Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
565+
if (!callee)
566+
return op.emitOpError()
567+
<< "'" << *calleeName
568+
<< "' does not reference a symbol in the current scope";
569+
auto fn = dyn_cast<LLVMFuncOp>(callee);
570+
if (!fn)
571+
return op.emitOpError() << "'" << *calleeName
572+
<< "' does not reference a valid LLVM function";
573+
574+
fnType = fn.getType();
575+
}
576+
if (!fnType.isFunctionTy())
577+
return op.emitOpError("callee does not have a functional type: ") << fnType;
578+
579+
// Verify that the operand and result types match the callee.
580+
581+
if (!fnType.isFunctionVarArg() &&
582+
fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
583+
return op.emitOpError()
584+
<< "incorrect number of operands ("
585+
<< (op.getNumOperands() - isIndirect)
586+
<< ") for callee (expecting: " << fnType.getFunctionNumParams()
587+
<< ")";
588+
589+
if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
590+
return op.emitOpError() << "incorrect number of operands ("
591+
<< (op.getNumOperands() - isIndirect)
592+
<< ") for varargs callee (expecting at least: "
593+
<< fnType.getFunctionNumParams() << ")";
594+
595+
for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
596+
if (op.getOperand(i + isIndirect).getType() !=
597+
fnType.getFunctionParamType(i))
598+
return op.emitOpError() << "operand type mismatch for operand " << i
599+
<< ": " << op.getOperand(i + isIndirect).getType()
600+
<< " != " << fnType.getFunctionParamType(i);
601+
602+
if (op.getNumResults() &&
603+
op.getResult(0).getType() != fnType.getFunctionResultType())
604+
return op.emitOpError()
605+
<< "result type mismatch: " << op.getResult(0).getType()
606+
<< " != " << fnType.getFunctionResultType();
607+
608+
return success();
609+
}
610+
537611
static void printCallOp(OpAsmPrinter &p, CallOp &op) {
538612
auto callee = op.callee();
539613
bool isDirect = callee.hasValue();

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,75 @@ func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
125125

126126
// -----
127127

128+
func @invalid_call() {
129+
// expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
130+
"llvm.call"() : () -> ()
131+
}
132+
133+
// -----
134+
135+
func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
136+
// expected-error@+1 {{expected function type}}
137+
llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
138+
}
139+
140+
// -----
141+
142+
func @call_unknown_symbol() {
143+
// expected-error@+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}}
144+
llvm.call @missing_callee() : () -> ()
145+
}
146+
147+
// -----
148+
149+
func @standard_func_callee()
150+
151+
func @call_non_llvm() {
152+
// expected-error@+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}}
153+
llvm.call @standard_func_callee() : () -> ()
154+
}
155+
156+
// -----
157+
158+
func @call_non_llvm_indirect(%arg0 : i32) {
159+
// expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}}
160+
"llvm.call"(%arg0) : (i32) -> ()
161+
}
162+
163+
// -----
164+
165+
llvm.func @callee_func(!llvm.i8) -> ()
166+
167+
func @callee_arg_mismatch(%arg0 : !llvm.i32) {
168+
// expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
169+
llvm.call @callee_func(%arg0) : (!llvm.i32) -> ()
170+
}
171+
172+
// -----
173+
174+
func @indirect_callee_arg_mismatch(%arg0 : !llvm.i32, %callee : !llvm.ptr<func<void(i8)>>) {
175+
// expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
176+
"llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, !llvm.i32) -> ()
177+
}
178+
179+
// -----
180+
181+
llvm.func @callee_func() -> (!llvm.i8)
182+
183+
func @callee_return_mismatch() {
184+
// expected-error@+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
185+
%res = llvm.call @callee_func() : () -> (!llvm.i32)
186+
}
187+
188+
// -----
189+
190+
func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) {
191+
// expected-error@+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
192+
"llvm.call"(%callee) : (!llvm.ptr<func<i8()>>) -> (!llvm.i32)
193+
}
194+
195+
// -----
196+
128197
func @call_too_many_results(%callee : () -> (i32,i32)) {
129198
// expected-error@+1 {{expected function with 0 or 1 result}}
130199
llvm.call %callee() : () -> (i32, i32)

0 commit comments

Comments
 (0)