Skip to content

Commit d7bc3b7

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Add missing check to canonicalization of GenericOp that are identity ops.
The operation is an identity if the values yielded by the operation are the arguments of the basic block of that operation. Add this missing check. Differential Revision: https://reviews.llvm.org/D94819
1 parent ed0fd56 commit d7bc3b7

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2276,13 +2276,15 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
22762276
SmallVector<Value, 4> returnedArgs;
22772277
for (Value yieldVal : yieldOp.values()) {
22782278
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
2279-
if (!yieldArg)
2279+
if (!yieldArg || yieldArg.getOwner() != &body)
22802280
return failure();
22812281
unsigned argumentNumber = yieldArg.getArgNumber();
22822282
if (argumentNumber < numIndexArgs)
22832283
return failure();
22842284
returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
22852285
}
2286+
if (returnedArgs.size() != genericOp.getOperation()->getNumResults())
2287+
return failure();
22862288
rewriter.replaceOp(genericOp, returnedArgs);
22872289
return success();
22882290
}

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,56 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
615615
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
616616
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
617617
// CHECK: return %[[ARG1]], %[[ARG0]]
618+
619+
// -----
620+
621+
#map = affine_map<(d0, d1) -> (d0, d1)>
622+
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
623+
%c0 = constant 0 : index
624+
%c1 = constant 1 : index
625+
%cst = constant 1.000000e+00 : f32
626+
%0 = dim %arg0, %c0 : tensor<?x?xf32>
627+
%1 = dim %arg0, %c1 : tensor<?x?xf32>
628+
%2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
629+
br ^bb1(%cst : f32)
630+
631+
^bb1(%arg1 : f32):
632+
%3 = linalg.generic
633+
{indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
634+
ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
635+
^bb0(%arg2: f32, %arg3 : f32):
636+
linalg.yield %arg1 : f32
637+
} -> tensor<?x?xf32>
638+
return %3 : tensor<?x?xf32>
639+
}
640+
// CHECK-LABEL: func @keep_not_noop
641+
// CHECK: %[[RESULT:.+]] = linalg.generic
642+
// CHECK: return %[[RESULT]]
643+
644+
// -----
645+
646+
#map = affine_map<(d0, d1) -> (d0, d1)>
647+
func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
648+
-> (tensor<?x?xf32>, tensor<?x?xf32>) {
649+
%c0 = constant 0 : index
650+
%c1 = constant 1 : index
651+
%cst = constant 1.000000e+00 : f32
652+
%0 = dim %arg0, %c0 : tensor<?x?xf32>
653+
%1 = dim %arg0, %c1 : tensor<?x?xf32>
654+
%2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
655+
br ^bb1(%cst : f32)
656+
657+
^bb1(%arg2 : f32):
658+
%3:2 = linalg.generic
659+
{indexing_maps = [#map, #map, #map, #map],
660+
iterator_types = ["parallel", "parallel"]}
661+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
662+
outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
663+
^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
664+
linalg.yield %arg2, %arg4 : f32, f32
665+
} -> tensor<?x?xf32>, tensor<?x?xf32>
666+
return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
667+
}
668+
// CHECK-LABEL: func @keep_not_noop
669+
// CHECK: %[[RESULT:.+]]:2 = linalg.generic
670+
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1

0 commit comments

Comments
 (0)