Skip to content

Commit 9a5d681

Browse files
bjacobGroverkss
authored andcommitted
Revert "[mlir][arith] Delete mul ext canonicalizations (llvm#144844)"
This reverts commit eb694b2.
1 parent 4983830 commit 9a5d681

File tree

3 files changed

+143
-5
lines changed

3 files changed

+143
-5
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
273273
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
274274
(SelectOp $pred, $a, $c)>;
275275

276-
// select(pred, false, true) => not(pred)
276+
// select(pred, false, true) => not(pred)
277277
def SelectI1ToNot :
278278
Pat<(SelectOp $pred,
279279
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -361,6 +361,10 @@ def OrOfExtSI :
361361
// TruncIOp
362362
//===----------------------------------------------------------------------===//
363363

364+
def ValuesWithSameType :
365+
Constraint<
366+
CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>;
367+
364368
def ValueWiderThan :
365369
Constraint<And<[
366370
CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">,
@@ -393,6 +397,28 @@ def TruncIShrSIToTrunciShrUI :
393397
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
394398
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
395399

400+
// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
401+
def TruncIShrUIMulIToMulSIExtended :
402+
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
403+
(Arith_MulIOp:$mul
404+
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
405+
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
406+
(Arith_MulSIExtendedOp:$res__1 $x, $y),
407+
[(ValuesWithSameType $tr, $x, $y),
408+
(ValueWiderThan $mul, $x),
409+
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
410+
411+
// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
412+
def TruncIShrUIMulIToMulUIExtended :
413+
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
414+
(Arith_MulIOp:$mul
415+
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
416+
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
417+
(Arith_MulUIExtendedOp:$res__1 $x, $y),
418+
[(ValuesWithSameType $tr, $x, $y),
419+
(ValueWiderThan $mul, $x),
420+
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
421+
396422
//===----------------------------------------------------------------------===//
397423
// TruncIOp
398424
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,9 +1507,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
15071507

15081508
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
15091509
MLIRContext *context) {
1510-
patterns
1511-
.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1512-
context);
1510+
patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1511+
TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1512+
context);
15131513
}
15141514

15151515
LogicalResult arith::TruncIOp::verify() {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
10001000

10011001

10021002
// CHECK-LABEL: @foldSubXX_tensor
1003-
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
1003+
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
10041004
// CHECK: %[[sub:.+]] = arith.subi
10051005
// CHECK: return %[[c0]], %[[sub]]
10061006
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2952,6 +2952,118 @@ func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 {
29522952
return %hi : i32
29532953
}
29542954

2955+
// CHECK-LABEL: @wideMulToMulSIExtended
2956+
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
2957+
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32
2958+
// CHECK-NEXT: return %[[HIGH]] : i32
2959+
func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
2960+
%x = arith.extsi %a: i32 to i64
2961+
%y = arith.extsi %b: i32 to i64
2962+
%m = arith.muli %x, %y: i64
2963+
%c32 = arith.constant 32: i64
2964+
%sh = arith.shrui %m, %c32 : i64
2965+
%hi = arith.trunci %sh: i64 to i32
2966+
return %hi : i32
2967+
}
2968+
2969+
// CHECK-LABEL: @wideMulToMulSIExtendedVector
2970+
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
2971+
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
2972+
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
2973+
func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
2974+
%x = arith.extsi %a: vector<3xi32> to vector<3xi64>
2975+
%y = arith.extsi %b: vector<3xi32> to vector<3xi64>
2976+
%m = arith.muli %x, %y: vector<3xi64>
2977+
%c32 = arith.constant dense<32>: vector<3xi64>
2978+
%sh = arith.shrui %m, %c32 : vector<3xi64>
2979+
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
2980+
return %hi : vector<3xi32>
2981+
}
2982+
2983+
// CHECK-LABEL: @wideMulToMulUIExtended
2984+
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
2985+
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32
2986+
// CHECK-NEXT: return %[[HIGH]] : i32
2987+
func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
2988+
%x = arith.extui %a: i32 to i64
2989+
%y = arith.extui %b: i32 to i64
2990+
%m = arith.muli %x, %y: i64
2991+
%c32 = arith.constant 32: i64
2992+
%sh = arith.shrui %m, %c32 : i64
2993+
%hi = arith.trunci %sh: i64 to i32
2994+
return %hi : i32
2995+
}
2996+
2997+
// CHECK-LABEL: @wideMulToMulUIExtendedVector
2998+
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
2999+
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
3000+
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
3001+
func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
3002+
%x = arith.extui %a: vector<3xi32> to vector<3xi64>
3003+
%y = arith.extui %b: vector<3xi32> to vector<3xi64>
3004+
%m = arith.muli %x, %y: vector<3xi64>
3005+
%c32 = arith.constant dense<32>: vector<3xi64>
3006+
%sh = arith.shrui %m, %c32 : vector<3xi64>
3007+
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
3008+
return %hi : vector<3xi32>
3009+
}
3010+
3011+
// CHECK-LABEL: @wideMulToMulIExtendedMixedExt
3012+
// CHECK: arith.muli
3013+
// CHECK: arith.shrui
3014+
// CHECK: arith.trunci
3015+
func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 {
3016+
%x = arith.extsi %a: i32 to i64
3017+
%y = arith.extui %b: i32 to i64
3018+
%m = arith.muli %x, %y: i64
3019+
%c32 = arith.constant 32: i64
3020+
%sh = arith.shrui %m, %c32 : i64
3021+
%hi = arith.trunci %sh: i64 to i32
3022+
return %hi : i32
3023+
}
3024+
3025+
// CHECK-LABEL: @wideMulToMulSIExtendedBadExt
3026+
// CHECK: arith.muli
3027+
// CHECK: arith.shrui
3028+
// CHECK: arith.trunci
3029+
func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 {
3030+
%x = arith.extsi %a: i16 to i64
3031+
%y = arith.extsi %b: i16 to i64
3032+
%m = arith.muli %x, %y: i64
3033+
%c32 = arith.constant 32: i64
3034+
%sh = arith.shrui %m, %c32 : i64
3035+
%hi = arith.trunci %sh: i64 to i32
3036+
return %hi : i32
3037+
}
3038+
3039+
// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1
3040+
// CHECK: arith.muli
3041+
// CHECK: arith.shrui
3042+
// CHECK: arith.trunci
3043+
func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 {
3044+
%x = arith.extsi %a: i32 to i64
3045+
%y = arith.extsi %b: i32 to i64
3046+
%m = arith.muli %x, %y: i64
3047+
%c33 = arith.constant 33: i64
3048+
%sh = arith.shrui %m, %c33 : i64
3049+
%hi = arith.trunci %sh: i64 to i32
3050+
return %hi : i32
3051+
}
3052+
3053+
// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2
3054+
// CHECK: arith.muli
3055+
// CHECK: arith.shrui
3056+
// CHECK: arith.trunci
3057+
func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
3058+
%x = arith.extsi %a: i32 to i64
3059+
%y = arith.extsi %b: i32 to i64
3060+
%m = arith.muli %x, %y: i64
3061+
%c31 = arith.constant 31: i64
3062+
%sh = arith.shrui %m, %c31 : i64
3063+
%hi = arith.trunci %sh: i64 to i32
3064+
return %hi : i32
3065+
}
3066+
29553067
// CHECK-LABEL: @foldShli0
29563068
// CHECK-SAME: (%[[ARG:.*]]: i64)
29573069
// CHECK: return %[[ARG]] : i64

0 commit comments

Comments
 (0)