
Commit 50b8a3c

[mlir][linalg] Refactor EraseIdentityGenericOp to be reused by other LinalgOps (#80466)
This refactors the rewrite pattern so that any `LinalgOp`'s canonicalization can reuse it to remove identity ops. The canonicalization is additionally applied to `BroadcastOp`.
1 parent: d9c3066
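Because the erase-identity rewrite is now a class template parameterized on the op type, other Linalg named ops could opt into the same canonicalization. As a purely illustrative sketch (not part of this commit), hooking it up for another op such as `linalg.transpose` would amount to setting `hasCanonicalizer = 1` in its ODS definition and adding one registration in LinalgOps.cpp:

// Hypothetical reuse sketch, not in this commit: TransposeOp stands in for
// "any LinalgOp" and would also need hasCanonicalizer = 1 in its ODS def.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<TransposeOp>>(context);
}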

File tree: 3 files changed, +38 -19 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 0 deletions
@@ -531,6 +531,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
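For context, `let hasCanonicalizer = 1;` tells ODS to declare the canonicalization hook on the generated op class; the definition is then supplied in LinalgOps.cpp below. The generated declaration looks roughly like this (a sketch, not verbatim tablegen output):

// Declared on the generated BroadcastOp class when hasCanonicalizer = 1;
// the LinalgOps.cpp diff below provides the matching definition.
static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
                                        ::mlir::MLIRContext *context);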

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 25 additions & 19 deletions
@@ -1087,43 +1087,44 @@ LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
-/// Remove generic operations (on tensors) that are just copying
+/// Remove any linalg operation (on tensors) that are just copying
 /// the values from inputs to the results. Requirements are
 /// 1) All iterator types are parallel
 /// 2) The body contains just a yield operation with the yielded values being
 ///    the arguments corresponding to the operands.
-struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+template <typename OpTy>
+struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(OpTy linalgOp,
                                 PatternRewriter &rewriter) const override {
     // Check all indexing maps are identity.
-    if (llvm::any_of(genericOp.getIndexingMapsArray(),
+    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                      [](AffineMap map) { return !map.isIdentity(); }))
      return failure();
 
     // Check that the body of the linalg operation is just a linalg.yield
     // operation.
-    Block &body = genericOp.getRegion().front();
+    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();
 
    // In the buffer case, we need to check exact buffer equality.
-    if (genericOp.hasPureBufferSemantics()) {
-      if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
-          genericOp.getDpsInputOperand(0)->get() ==
-              genericOp.getDpsInitOperand(0)->get()) {
-        rewriter.eraseOp(genericOp);
+    if (linalgOp.hasPureBufferSemantics()) {
+      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
+          linalgOp.getDpsInputOperand(0)->get() ==
+              linalgOp.getDpsInitOperand(0)->get()) {
+        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }
 
    // Mixed semantics is not supported yet.
-    if (!genericOp.hasPureTensorSemantics())
+    if (!linalgOp.hasPureTensorSemantics())
      return failure();
 
    // Get the argument number of the returned values. That is the operand
@@ -1134,8 +1135,8 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
-      Value returnedArg = genericOp->getOperand(argumentNumber);
-      Type resultType = genericOp->getResult(yieldVal.index()).getType();
+      Value returnedArg = linalgOp->getOperand(argumentNumber);
+      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
@@ -1145,21 +1146,21 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
-              genericOp.getLoc(), resultType, returnedArg);
+              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
-              genericOp.getLoc(), resultType, returnedArg);
+              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }
 
-    if (returnedArgs.size() != genericOp->getNumResults())
+    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
-    rewriter.replaceOp(genericOp, returnedArgs);
+    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
@@ -1168,7 +1169,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<EraseIdentityGenericOp>(context);
+  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
 }
 
 LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
@@ -1907,6 +1908,11 @@ void BroadcastOp::getEffects(
                        getDpsInits());
 }
 
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
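These patterns are picked up by the standard `-canonicalize` pass. For a rough idea of how they could also be applied directly in a custom pass, here is a minimal sketch; it assumes `op` is, e.g., a `func.func`, and the helper name `eraseIdentityLinalgCopies` is made up for illustration:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collect the Linalg canonicalization patterns touched by this commit and
// apply them greedily; identity linalg.generic / linalg.broadcast copies are
// folded away in the process.
static mlir::LogicalResult eraseIdentityLinalgCopies(mlir::Operation *op) {
  mlir::MLIRContext *ctx = op->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::linalg::GenericOp::getCanonicalizationPatterns(patterns, ctx);
  mlir::linalg::BroadcastOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}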

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
@@ -1017,3 +1017,15 @@ func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tenso
   %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %copy : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @broadcast_same_shape(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
+// CHECK-NOT: linalg.broadcast
+// CHECK: return %[[ARG0]] : tensor<2x3xf32>
+func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
+  return %0 : tensor<2x3xf32>
+}
