Skip to content

Commit c1f8568

Browse files
imaihalbondhugula
authored andcommitted
[MLIR] Fix for updating function signature in normalizing memrefs
Normalizing memrefs failed when a caller of symbolic use in a function can not be casted to `CallOp`. This patch avoids the failure by checking the result of the casting. If the caller can not be casted to `CallOp`, it is skipped. Differential Revision: https://reviews.llvm.org/D87746
1 parent 6caf3fb commit c1f8568

File tree

4 files changed

+32
-10
lines changed

4 files changed

+32
-10
lines changed

mlir/lib/Transforms/NormalizeMemRefs.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,23 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
263263
// type at the caller site.
264264
Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
265265
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
266-
Operation *callOp = symbolUse.getUser();
267-
OpBuilder builder(callOp);
268-
StringRef callee = cast<CallOp>(callOp).getCallee();
266+
Operation *userOp = symbolUse.getUser();
267+
OpBuilder builder(userOp);
268+
// When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
269+
// that the non-CallOp has no memrefs to be replaced.
270+
// TODO: Handle cases where a non-CallOp symbol use of a function deals with
271+
// memrefs.
272+
auto callOp = dyn_cast<CallOp>(userOp);
273+
if (!callOp)
274+
continue;
275+
StringRef callee = callOp.getCallee();
269276
Operation *newCallOp = builder.create<CallOp>(
270-
callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
271-
callOp->getOperands());
277+
userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
278+
userOp->getOperands());
272279
bool replacingMemRefUsesFailed = false;
273280
bool returnTypeChanged = false;
274-
for (unsigned resIndex : llvm::seq<unsigned>(0, callOp->getNumResults())) {
275-
OpResult oldResult = callOp->getResult(resIndex);
281+
for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
282+
OpResult oldResult = userOp->getResult(resIndex);
276283
OpResult newResult = newCallOp->getResult(resIndex);
277284
// This condition ensures that if the result is not of type memref or if
278285
// the resulting memref was already having a trivial map layout then we
@@ -302,8 +309,8 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
302309
if (replacingMemRefUsesFailed)
303310
continue;
304311
// Replace all uses for other non-memref result types.
305-
callOp->replaceAllUsesWith(newCallOp);
306-
callOp->erase();
312+
userOp->replaceAllUsesWith(newCallOp);
313+
userOp->erase();
307314
if (returnTypeChanged) {
308315
// Since the return type changed it might lead to a change in function's
309316
// signature.

mlir/test/Transforms/normalize-memrefs-ops.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,7 @@ func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
8989
// CHECK: dealloc %[[v1]] : memref<1x16x14x14xf32>
9090
return
9191
}
92+
93+
// Test with an arbitrary op that references the function symbol.
94+
95+
"test.op_funcref"() {func = @test_norm_mix} : () -> ()

mlir/test/lib/Dialect/Test/TestDialect.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Dialect/Traits.h"
1818
#include "mlir/IR/Dialect.h"
19+
#include "mlir/IR/Function.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/IR/OpImplementation.h"
2122
#include "mlir/IR/RegionKindInterface.h"
@@ -29,7 +30,6 @@
2930

3031
#include "TestOpEnums.h.inc"
3132

32-
3333
#include "TestOpStructs.h.inc"
3434
#include "TestOpsDialect.h.inc"
3535

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,17 @@ def OpNonNorm : TEST_Op<"op_nonnorm"> {
629629
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
630630
}
631631

632+
// Test for memrefs normalization of an op with a reference to a function
633+
// symbol.
634+
def OpFuncRef : TEST_Op<"op_funcref"> {
635+
let summary = "Test op with a reference to a function symbol";
636+
let description = [{
637+
The "test.op_funcref" is a test op with a reference to a function symbol.
638+
}];
639+
let builders = [OpBuilder<[{OpBuilder &builder, OperationState &state,
640+
FuncOp function}]>];
641+
}
642+
632643
// Pattern add the argument plus a increasing static number hidden in
633644
// OpMTest function. That value is set into the optional argument.
634645
// That way, we will know if operations is called once or twice.

0 commit comments

Comments
 (0)