Skip to content

Commit 51bc82d

Browse files
committed
[mlir] Implement SymbolUserOpInterface in LLVM::CallOp
Avoid expensive calls to `SymbolTable::lookupNearestSymbolFrom` in verifier Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D131285
1 parent 5c16eeb commit 51bc82d

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
647647

648648
def LLVM_CallOp : LLVM_Op<"call",
649649
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
650-
DeclareOpInterfaceMethods<CallOpInterface>]> {
650+
DeclareOpInterfaceMethods<CallOpInterface>,
651+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
651652
let summary = "Call to an LLVM function.";
652653
let description = [{
653654

@@ -701,8 +702,8 @@ def LLVM_CallOp : LLVM_Op<"call",
701702
StringAttr::get($_builder.getContext(), callee), operands);
702703
}]>];
703704
let hasCustomAssemblyFormat = 1;
704-
let hasVerifier = 1;
705705
}
706+
706707
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
707708
let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
708709
let results = (outs LLVM_Type:$res);

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,7 @@ Operation::operand_range CallOp::getArgOperands() {
11601160
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
11611161
}
11621162

1163-
LogicalResult CallOp::verify() {
1163+
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
11641164
if (getNumResults() > 1)
11651165
return emitOpError("must have 0 or 1 result");
11661166

@@ -1184,7 +1184,7 @@ LogicalResult CallOp::verify() {
11841184
fnType = ptrType.getElementType();
11851185
} else {
11861186
Operation *callee =
1187-
SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
1187+
symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
11881188
if (!callee)
11891189
return emitOpError()
11901190
<< "'" << calleeName.getValue()

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,27 +191,31 @@ func.func @store_malformed_elem_type(%foo: !llvm.ptr, %bar: f32) {
191191
func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
192192
// expected-error@+1 {{expected function type}}
193193
llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
194+
llvm.return
194195
}
195196

196197
// -----
197198

198199
func.func @invalid_call() {
199200
// expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
200201
"llvm.call"() : () -> ()
202+
llvm.return
201203
}
202204

203205
// -----
204206

205207
func.func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : i8) {
206208
// expected-error@+1 {{expected function type}}
207209
llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
210+
llvm.return
208211
}
209212

210213
// -----
211214

212215
func.func @call_unknown_symbol() {
213216
// expected-error@+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}}
214217
llvm.call @missing_callee() : () -> ()
218+
llvm.return
215219
}
216220

217221
// -----
@@ -221,13 +225,15 @@ func.func private @standard_func_callee()
221225
func.func @call_non_llvm() {
222226
// expected-error@+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}}
223227
llvm.call @standard_func_callee() : () -> ()
228+
llvm.return
224229
}
225230

226231
// -----
227232

228233
func.func @call_non_llvm_indirect(%arg0 : tensor<*xi32>) {
229234
// expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type}}
230235
"llvm.call"(%arg0) : (tensor<*xi32>) -> ()
236+
llvm.return
231237
}
232238

233239
// -----
@@ -237,13 +243,15 @@ llvm.func @callee_func(i8) -> ()
237243
func.func @callee_arg_mismatch(%arg0 : i32) {
238244
// expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}}
239245
llvm.call @callee_func(%arg0) : (i32) -> ()
246+
llvm.return
240247
}
241248

242249
// -----
243250

244251
func.func @indirect_callee_arg_mismatch(%arg0 : i32, %callee : !llvm.ptr<func<void(i8)>>) {
245252
// expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: 'i32' != 'i8'}}
246253
"llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void(i8)>>, i32) -> ()
254+
llvm.return
247255
}
248256

249257
// -----
@@ -253,34 +261,39 @@ llvm.func @callee_func() -> (i8)
253261
func.func @callee_return_mismatch() {
254262
// expected-error@+1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}}
255263
%res = llvm.call @callee_func() : () -> (i32)
264+
llvm.return
256265
}
257266

258267
// -----
259268

260269
func.func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8()>>) {
261270
// expected-error@+1 {{'llvm.call' op result type mismatch: 'i32' != 'i8'}}
262271
"llvm.call"(%callee) : (!llvm.ptr<func<i8()>>) -> (i32)
272+
llvm.return
263273
}
264274

265275
// -----
266276

267277
func.func @call_too_many_results(%callee : () -> (i32,i32)) {
268278
// expected-error@+1 {{expected function with 0 or 1 result}}
269279
llvm.call %callee() : () -> (i32, i32)
280+
llvm.return
270281
}
271282

272283
// -----
273284

274285
func.func @call_non_llvm_result(%callee : () -> (tensor<*xi32>)) {
275286
// expected-error@+1 {{expected result to have LLVM type}}
276287
llvm.call %callee() : () -> (tensor<*xi32>)
288+
llvm.return
277289
}
278290

279291
// -----
280292

281293
func.func @call_non_llvm_input(%callee : (tensor<*xi32>) -> (), %arg : tensor<*xi32>) {
282294
// expected-error@+1 {{expected LLVM types as inputs}}
283295
llvm.call %callee(%arg) : (tensor<*xi32>) -> ()
296+
llvm.return
284297
}
285298

286299
// -----

0 commit comments

Comments
 (0)