Skip to content

Commit fbe91fe

Browse files
committed
[mlir][arith] Canonicalize addi(x, muli(y, -1)) -> subi(x, y)
These propagate all the way down to SPIR-V and result in some fishy code with large constants. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D145423
1 parent 260bae5 commit fbe91fe

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,27 @@ def AddISubConstantLHS :
4949
(ConstantLikeMatcher APIntAttr:$c1)),
5050
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
5151

52+
def IsScalarOrSplatNegativeOne :
53+
Constraint<And<[
54+
CPred<"succeeded(getIntOrSplatIntValue($0))">,
55+
CPred<"getIntOrSplatIntValue($0)->isAllOnes()">]>>;
56+
57+
// addi(x, muli(y, -1)) -> subi(x, y)
58+
def AddIMulNegativeOneRhs :
59+
Pat<(Arith_AddIOp
60+
$x,
61+
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
62+
(Arith_SubIOp $x, $y),
63+
[(IsScalarOrSplatNegativeOne $c0)]>;
64+
65+
// addi(muli(x, -1), y) -> subi(y, x)
66+
def AddIMulNegativeOneLhs :
67+
Pat<(Arith_AddIOp
68+
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
69+
$y),
70+
(Arith_SubIOp $y, $x),
71+
[(IsScalarOrSplatNegativeOne $c0)]>;
72+
5273
//===----------------------------------------------------------------------===//
5374
// AddUIExtendedOp
5475
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,8 @@ OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
258258

259259
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
260260
MLIRContext *context) {
261-
patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
262-
context);
261+
patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
262+
AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
263263
}
264264

265265
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,72 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
735735
return %add : index
736736
}
737737

738+
// CHECK-LABEL: @addiMuliToSubiRhsI32
739+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
740+
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
741+
// CHECK: return %[[SUB]]
742+
func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
743+
%c-1 = arith.constant -1 : i32
744+
%neg = arith.muli %arg1, %c-1 : i32
745+
%add = arith.addi %arg0, %neg : i32
746+
return %add : i32
747+
}
748+
749+
// CHECK-LABEL: @addiMuliToSubiRhsIndex
750+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
751+
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
752+
// CHECK: return %[[SUB]]
753+
func.func @addiMuliToSubiRhsIndex(%arg0: index, %arg1: index) -> index {
754+
%c-1 = arith.constant -1 : index
755+
%neg = arith.muli %arg1, %c-1 : index
756+
%add = arith.addi %arg0, %neg : index
757+
return %add : index
758+
}
759+
760+
// CHECK-LABEL: @addiMuliToSubiRhsVector
761+
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
762+
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
763+
// CHECK: return %[[SUB]]
764+
func.func @addiMuliToSubiRhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
765+
%c-1 = arith.constant dense<-1> : vector<3xi64>
766+
%neg = arith.muli %arg1, %c-1 : vector<3xi64>
767+
%add = arith.addi %arg0, %neg : vector<3xi64>
768+
return %add : vector<3xi64>
769+
}
770+
771+
// CHECK-LABEL: @addiMuliToSubiLhsI32
772+
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
773+
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
774+
// CHECK: return %[[SUB]]
775+
func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
776+
%c-1 = arith.constant -1 : i32
777+
%neg = arith.muli %arg1, %c-1 : i32
778+
%add = arith.addi %neg, %arg0 : i32
779+
return %add : i32
780+
}
781+
782+
// CHECK-LABEL: @addiMuliToSubiLhsIndex
783+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
784+
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
785+
// CHECK: return %[[SUB]]
786+
func.func @addiMuliToSubiLhsIndex(%arg0: index, %arg1: index) -> index {
787+
%c-1 = arith.constant -1 : index
788+
%neg = arith.muli %arg1, %c-1 : index
789+
%add = arith.addi %neg, %arg0 : index
790+
return %add : index
791+
}
792+
793+
// CHECK-LABEL: @addiMuliToSubiLhsVector
794+
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
795+
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
796+
// CHECK: return %[[SUB]]
797+
func.func @addiMuliToSubiLhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
798+
%c-1 = arith.constant dense<-1> : vector<3xi64>
799+
%neg = arith.muli %arg1, %c-1 : vector<3xi64>
800+
%add = arith.addi %neg, %arg0 : vector<3xi64>
801+
return %add : vector<3xi64>
802+
}
803+
738804
// CHECK-LABEL: @adduiExtendedZeroRhs
739805
// CHECK-NEXT: %[[false:.+]] = arith.constant false
740806
// CHECK-NEXT: return %arg0, %[[false]]

0 commit comments

Comments
 (0)