Skip to content

Commit 07a029c

Browse files
not-jennirsuderman
authored andcommitted
Canonicalization for add to no-op if one of the inputs is zero
Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D113207
1 parent 795ff77 commit 07a029c

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,8 @@ def Tosa_AddOp : Tosa_Op<"add", [
442442
let results = (outs
443443
Tosa_Tensor:$output
444444
);
445+
446+
let hasCanonicalizer = 1;
445447
}
446448

447449
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,55 @@ void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
289289
results.insert<NoOpOptimization>(context);
290290
}
291291

292+
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
293+
using OpRewritePattern::OpRewritePattern;
294+
295+
LogicalResult matchAndRewrite(tosa::AddOp op,
296+
PatternRewriter &rewriter) const override {
297+
auto input1 = op.input1();
298+
auto input2 = op.input2();
299+
300+
DenseElementsAttr input1Attr;
301+
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
302+
input2.getType() == op.getType()) {
303+
if (input1Attr.getType().getElementType().isa<FloatType>() &&
304+
input1Attr.getSplatValue<APFloat>().isZero()) {
305+
rewriter.replaceOp(op, op.input2());
306+
return success();
307+
}
308+
309+
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
310+
input1Attr.getSplatValue<APInt>().isZero()) {
311+
rewriter.replaceOp(op, op.input2());
312+
return success();
313+
}
314+
}
315+
316+
DenseElementsAttr input2Attr;
317+
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
318+
input1.getType() == op.getType()) {
319+
if (input2Attr.getType().getElementType().isa<FloatType>() &&
320+
input2Attr.getSplatValue<APFloat>().isZero()) {
321+
rewriter.replaceOp(op, op.input1());
322+
return success();
323+
}
324+
325+
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
326+
input2Attr.getSplatValue<APInt>().isZero()) {
327+
rewriter.replaceOp(op, op.input1());
328+
return success();
329+
}
330+
}
331+
332+
return failure();
333+
}
334+
};
335+
336+
void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
337+
MLIRContext *context) {
338+
results.insert<AddZeroOptimization>(context);
339+
}
340+
292341
//===----------------------------------------------------------------------===//
293342
// Operator Folders.
294343
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,38 @@ func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
99

1010
// -----
1111

12+
// CHECK-LABEL: @add_zero_different_shape
13+
func @add_zero_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
14+
// CHECK: tosa.add
15+
%zeros = "tosa.const"() {value = dense<0.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
16+
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
17+
return %1 : tensor<4x2x3xf32>
18+
}
19+
20+
// -----
21+
22+
// CHECK-LABEL: @add_zero_float
23+
func @add_zero_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
24+
// CHECK: return %arg0
25+
// CHECK-NOT: tosa.add
26+
%zeros = "tosa.const"() {value = dense<0.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
27+
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
28+
return %1 : tensor<2x3xf32>
29+
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: @add_zero_int
34+
func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
35+
// CHECK: return %arg0
36+
// CHECK-NOT: tosa.add
37+
%zeros = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
38+
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
39+
return %1 : tensor<2x3xi32>
40+
}
41+
42+
// -----
43+
1244
// CHECK-LABEL: @cast_fold
1345
func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
1446
// CHECK: return %arg0

0 commit comments

Comments
 (0)