Skip to content

Commit 7bdba95

Browse files
authored
[mlir][arith] Fix arith.select canonicalization patterns (#84685)
Because `arith.select` does not propagate poison of the second or third operand depending on the condition, some canonicalization patterns are currently incorrect. This patch removes these incorrect patterns, and adds a new pattern to fix the case of `i1` select with constants. Patterns that are removed: * select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y) * select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y) * select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y) * select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y) * arith.select %arg, %x, %y : i1 => and(%arg, %x) or and(!%arg, %y) Pattern that is added: * select(pred, false, true) => not(pred) for i1 The first two patterns are incorrect when `predB` is poison and `predA` is false, as a non-poison `y` gets compiled to `poison`. The next two patterns are incorrect when `predB` is poison and `predA` is true, as a non-poison `x` gets compiled to `poison`. The last pattern is incorrect as it propagates poison from all operands afer compilation.
1 parent 207e45f commit 7bdba95

File tree

3 files changed

+8
-141
lines changed

3 files changed

+8
-141
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,6 @@ def CmpIExtUI :
253253
// SelectOp
254254
//===----------------------------------------------------------------------===//
255255

256-
def GetScalarOrVectorTrueAttribute :
257-
NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
258-
259256
// select(not(pred), a, b) => select(pred, b, a)
260257
def SelectNotCond :
261258
Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -272,31 +269,12 @@ def RedundantSelectFalse :
272269
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
273270
(SelectOp $pred, $a, $c)>;
274271

275-
// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
276-
def SelectAndCond :
277-
Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
278-
(SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
279-
280-
// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
281-
def SelectAndNotCond :
282-
Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
283-
(SelectOp (Arith_AndIOp $predA,
284-
(Arith_XOrIOp $predB,
285-
(Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
286-
$x, $y)>;
287-
288-
// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
289-
def SelectOrCond :
290-
Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
291-
(SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
292-
293-
// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
294-
def SelectOrNotCond :
295-
Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
296-
(SelectOp (Arith_OrIOp $predA,
297-
(Arith_XOrIOp $predB,
298-
(Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
299-
$x, $y)>;
272+
// select(pred, false, true) => not(pred)
273+
def SelectI1ToNot :
274+
Pat<(SelectOp $pred,
275+
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
276+
(ConstantLikeMatcher ConstantAttr<I1Attr, "1">)),
277+
(Arith_XOrIOp $pred, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))>;
300278

301279
//===----------------------------------------------------------------------===//
302280
// IndexCastOp

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,6 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
969969
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
970970
}
971971

972-
973972
//===----------------------------------------------------------------------===//
974973
// MaxSIOp
975974
//===----------------------------------------------------------------------===//
@@ -2173,35 +2172,6 @@ void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
21732172
// SelectOp
21742173
//===----------------------------------------------------------------------===//
21752174

2176-
// Transforms a select of a boolean to arithmetic operations
2177-
//
2178-
// arith.select %arg, %x, %y : i1
2179-
//
2180-
// becomes
2181-
//
2182-
// and(%arg, %x) or and(!%arg, %y)
2183-
struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
2184-
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
2185-
2186-
LogicalResult matchAndRewrite(arith::SelectOp op,
2187-
PatternRewriter &rewriter) const override {
2188-
if (!op.getType().isInteger(1))
2189-
return failure();
2190-
2191-
Value falseConstant =
2192-
rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
2193-
Value notCondition = rewriter.create<arith::XOrIOp>(
2194-
op.getLoc(), op.getCondition(), falseConstant);
2195-
2196-
Value trueVal = rewriter.create<arith::AndIOp>(
2197-
op.getLoc(), op.getCondition(), op.getTrueValue());
2198-
Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
2199-
op.getFalseValue());
2200-
rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
2201-
return success();
2202-
}
2203-
};
2204-
22052175
// select %arg, %c1, %c0 => extui %arg
22062176
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
22072177
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
@@ -2238,9 +2208,8 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
22382208

22392209
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
22402210
MLIRContext *context) {
2241-
results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
2242-
SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
2243-
SelectNotCond, SelectToExtUI>(context);
2211+
results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2212+
SelectI1ToNot, SelectToExtUI>(context);
22442213
}
22452214

22462215
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,6 @@ func.func @selToNot(%arg0: i1) -> i1 {
116116
return %res : i1
117117
}
118118

119-
// CHECK-LABEL: @selToArith
120-
// CHECK-NEXT: %[[trueval:.+]] = arith.constant true
121-
// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
122-
// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
123-
// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
124-
// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
125-
// CHECK: return %[[res]]
126-
func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
127-
%res = arith.select %arg0, %arg1, %arg2 : i1
128-
return %res : i1
129-
}
130-
131119
// CHECK-LABEL: @redundantSelectTrue
132120
// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
133121
// CHECK-NEXT: return %[[res]]
@@ -160,74 +148,6 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
160148
return %res1, %res2 : i32, i32
161149
}
162150

163-
// CHECK-LABEL: @selAndCond
164-
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
165-
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
166-
// CHECK-NEXT: return %[[res]]
167-
func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
168-
%sel = arith.select %arg0, %arg2, %arg3 : i32
169-
%res = arith.select %arg1, %sel, %arg3 : i32
170-
return %res : i32
171-
}
172-
173-
// CHECK-LABEL: @selAndNotCond
174-
// CHECK-NEXT: %[[one:.+]] = arith.constant true
175-
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
176-
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
177-
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
178-
// CHECK-NEXT: return %[[res]]
179-
func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
180-
%sel = arith.select %arg0, %arg2, %arg3 : i32
181-
%res = arith.select %arg1, %sel, %arg2 : i32
182-
return %res : i32
183-
}
184-
185-
// CHECK-LABEL: @selAndNotCondVec
186-
// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
187-
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
188-
// CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
189-
// CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
190-
// CHECK-NEXT: return %[[res]]
191-
func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
192-
%sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
193-
%res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
194-
return %res : vector<4xi32>
195-
}
196-
197-
// CHECK-LABEL: @selOrCond
198-
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
199-
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
200-
// CHECK-NEXT: return %[[res]]
201-
func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
202-
%sel = arith.select %arg0, %arg2, %arg3 : i32
203-
%res = arith.select %arg1, %arg2, %sel : i32
204-
return %res : i32
205-
}
206-
207-
// CHECK-LABEL: @selOrNotCond
208-
// CHECK-NEXT: %[[one:.+]] = arith.constant true
209-
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
210-
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
211-
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
212-
// CHECK-NEXT: return %[[res]]
213-
func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
214-
%sel = arith.select %arg0, %arg2, %arg3 : i32
215-
%res = arith.select %arg1, %arg3, %sel : i32
216-
return %res : i32
217-
}
218-
219-
// CHECK-LABEL: @selOrNotCondVec
220-
// CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
221-
// CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
222-
// CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
223-
// CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
224-
// CHECK-NEXT: return %[[res]]
225-
func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
226-
%sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
227-
%res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
228-
return %res : vector<4xi32>
229-
}
230-
231151
// Test case: Folding of comparisons with equal operands.
232152
// CHECK-LABEL: @cmpi_equal_operands
233153
// CHECK-DAG: %[[T:.*]] = arith.constant true

0 commit comments

Comments
 (0)