Skip to content

Commit eb694b2

Browse files
authored
[mlir][arith] Delete mul ext canonicalizations (#144844)
The Arith dialect includes patterns that canonicalize a sequence of: - trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y) - trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y) These patterns return the high word of an extended multiplication, which assumes that the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width. For example, the following code: ``` %x = arith.extui %a: i32 to i33 %y = arith.extui %b: i32 to i33 %m = arith.muli %x, %y: i33 %c1 = arith.constant 1: i33 %sh = arith.shrui %m, %c1 : i33 %hi = arith.trunci %sh: i33 to i32 ``` would incorrectly be canonicalized to: ``` _, %hi = arith.mului_extended %a, %b : i32 ``` This commit removes the faulty canonicalizations since they are not believed to be generally beneficial (c.f., the discussion of the alternative #144787 which fixes the canonicalizations).
1 parent 89efae9 commit eb694b2

File tree

3 files changed

+5
-143
lines changed

3 files changed

+5
-143
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
273273
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
274274
(SelectOp $pred, $a, $c)>;
275275

276-
// select(pred, false, true) => not(pred)
276+
// select(pred, false, true) => not(pred)
277277
def SelectI1ToNot :
278278
Pat<(SelectOp $pred,
279279
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -361,10 +361,6 @@ def OrOfExtSI :
361361
// TruncIOp
362362
//===----------------------------------------------------------------------===//
363363

364-
def ValuesWithSameType :
365-
Constraint<
366-
CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>;
367-
368364
def ValueWiderThan :
369365
Constraint<And<[
370366
CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">,
@@ -397,28 +393,6 @@ def TruncIShrSIToTrunciShrUI :
397393
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
398394
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
399395

400-
// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
401-
def TruncIShrUIMulIToMulSIExtended :
402-
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
403-
(Arith_MulIOp:$mul
404-
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
405-
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
406-
(Arith_MulSIExtendedOp:$res__1 $x, $y),
407-
[(ValuesWithSameType $tr, $x, $y),
408-
(ValueWiderThan $mul, $x),
409-
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
410-
411-
// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
412-
def TruncIShrUIMulIToMulUIExtended :
413-
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
414-
(Arith_MulIOp:$mul
415-
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
416-
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
417-
(Arith_MulUIExtendedOp:$res__1 $x, $y),
418-
[(ValuesWithSameType $tr, $x, $y),
419-
(ValueWiderThan $mul, $x),
420-
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
421-
422396
//===----------------------------------------------------------------------===//
423397
// TruncIOp
424398
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,9 +1509,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
15091509

15101510
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
15111511
MLIRContext *context) {
1512-
patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
1513-
TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
1514-
context);
1512+
patterns
1513+
.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1514+
context);
15151515
}
15161516

15171517
LogicalResult arith::TruncIOp::verify() {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 1 addition & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
10001000

10011001

10021002
// CHECK-LABEL: @foldSubXX_tensor
1003-
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
1003+
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
10041004
// CHECK: %[[sub:.+]] = arith.subi
10051005
// CHECK: return %[[c0]], %[[sub]]
10061006
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2952,118 +2952,6 @@ func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 {
29522952
return %hi : i32
29532953
}
29542954

2955-
// CHECK-LABEL: @wideMulToMulSIExtended
2956-
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
2957-
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32
2958-
// CHECK-NEXT: return %[[HIGH]] : i32
2959-
func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
2960-
%x = arith.extsi %a: i32 to i64
2961-
%y = arith.extsi %b: i32 to i64
2962-
%m = arith.muli %x, %y: i64
2963-
%c32 = arith.constant 32: i64
2964-
%sh = arith.shrui %m, %c32 : i64
2965-
%hi = arith.trunci %sh: i64 to i32
2966-
return %hi : i32
2967-
}
2968-
2969-
// CHECK-LABEL: @wideMulToMulSIExtendedVector
2970-
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
2971-
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
2972-
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
2973-
func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
2974-
%x = arith.extsi %a: vector<3xi32> to vector<3xi64>
2975-
%y = arith.extsi %b: vector<3xi32> to vector<3xi64>
2976-
%m = arith.muli %x, %y: vector<3xi64>
2977-
%c32 = arith.constant dense<32>: vector<3xi64>
2978-
%sh = arith.shrui %m, %c32 : vector<3xi64>
2979-
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
2980-
return %hi : vector<3xi32>
2981-
}
2982-
2983-
// CHECK-LABEL: @wideMulToMulUIExtended
2984-
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
2985-
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32
2986-
// CHECK-NEXT: return %[[HIGH]] : i32
2987-
func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
2988-
%x = arith.extui %a: i32 to i64
2989-
%y = arith.extui %b: i32 to i64
2990-
%m = arith.muli %x, %y: i64
2991-
%c32 = arith.constant 32: i64
2992-
%sh = arith.shrui %m, %c32 : i64
2993-
%hi = arith.trunci %sh: i64 to i32
2994-
return %hi : i32
2995-
}
2996-
2997-
// CHECK-LABEL: @wideMulToMulUIExtendedVector
2998-
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
2999-
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
3000-
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
3001-
func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
3002-
%x = arith.extui %a: vector<3xi32> to vector<3xi64>
3003-
%y = arith.extui %b: vector<3xi32> to vector<3xi64>
3004-
%m = arith.muli %x, %y: vector<3xi64>
3005-
%c32 = arith.constant dense<32>: vector<3xi64>
3006-
%sh = arith.shrui %m, %c32 : vector<3xi64>
3007-
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
3008-
return %hi : vector<3xi32>
3009-
}
3010-
3011-
// CHECK-LABEL: @wideMulToMulIExtendedMixedExt
3012-
// CHECK: arith.muli
3013-
// CHECK: arith.shrui
3014-
// CHECK: arith.trunci
3015-
func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 {
3016-
%x = arith.extsi %a: i32 to i64
3017-
%y = arith.extui %b: i32 to i64
3018-
%m = arith.muli %x, %y: i64
3019-
%c32 = arith.constant 32: i64
3020-
%sh = arith.shrui %m, %c32 : i64
3021-
%hi = arith.trunci %sh: i64 to i32
3022-
return %hi : i32
3023-
}
3024-
3025-
// CHECK-LABEL: @wideMulToMulSIExtendedBadExt
3026-
// CHECK: arith.muli
3027-
// CHECK: arith.shrui
3028-
// CHECK: arith.trunci
3029-
func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 {
3030-
%x = arith.extsi %a: i16 to i64
3031-
%y = arith.extsi %b: i16 to i64
3032-
%m = arith.muli %x, %y: i64
3033-
%c32 = arith.constant 32: i64
3034-
%sh = arith.shrui %m, %c32 : i64
3035-
%hi = arith.trunci %sh: i64 to i32
3036-
return %hi : i32
3037-
}
3038-
3039-
// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1
3040-
// CHECK: arith.muli
3041-
// CHECK: arith.shrui
3042-
// CHECK: arith.trunci
3043-
func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 {
3044-
%x = arith.extsi %a: i32 to i64
3045-
%y = arith.extsi %b: i32 to i64
3046-
%m = arith.muli %x, %y: i64
3047-
%c33 = arith.constant 33: i64
3048-
%sh = arith.shrui %m, %c33 : i64
3049-
%hi = arith.trunci %sh: i64 to i32
3050-
return %hi : i32
3051-
}
3052-
3053-
// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2
3054-
// CHECK: arith.muli
3055-
// CHECK: arith.shrui
3056-
// CHECK: arith.trunci
3057-
func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
3058-
%x = arith.extsi %a: i32 to i64
3059-
%y = arith.extsi %b: i32 to i64
3060-
%m = arith.muli %x, %y: i64
3061-
%c31 = arith.constant 31: i64
3062-
%sh = arith.shrui %m, %c31 : i64
3063-
%hi = arith.trunci %sh: i64 to i32
3064-
return %hi : i32
3065-
}
3066-
30672955
// CHECK-LABEL: @foldShli0
30682956
// CHECK-SAME: (%[[ARG:.*]]: i64)
30692957
// CHECK: return %[[ARG]] : i64

0 commit comments

Comments
 (0)