Skip to content

Commit 795b4ef

Browse files
authored
[MLIR] Add canonicalizations to all eligible index binary ops (#114000)
Generalizes the following canonicalization pattern to all associative and commutative binary ops in the `index` dialect. ``` x = v + c1 y = x + c2 --> y = x + (c1 + c2) ``` This includes: - `AddOp` - `MulOp` - `MaxSOp` - `MaxUOp` - `MinSOp` - `MinUOp` - `AndOp` - `OrOp` - `XOrOp` The operation folding is implemented using the existing folders since `createAndFold` is used in the canonicalization.
1 parent 9d09c6f commit 795b4ef

File tree

3 files changed

+183
-23
lines changed

3 files changed

+183
-23
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def Index_MulOp : IndexBinaryOp<"mul", [Commutative, Pure]> {
9595
%c = index.mul %a, %b
9696
```
9797
}];
98+
99+
let hasCanonicalizeMethod = 1;
98100
}
99101

100102
//===----------------------------------------------------------------------===//
@@ -263,6 +265,8 @@ def Index_MaxSOp : IndexBinaryOp<"maxs", [Commutative, Pure]> {
263265
%c = index.maxs %a, %b
264266
```
265267
}];
268+
269+
let hasCanonicalizeMethod = 1;
266270
}
267271

268272
//===----------------------------------------------------------------------===//
@@ -283,6 +287,8 @@ def Index_MaxUOp : IndexBinaryOp<"maxu", [Commutative, Pure]> {
283287
%c = index.maxu %a, %b
284288
```
285289
}];
290+
291+
let hasCanonicalizeMethod = 1;
286292
}
287293

288294
//===----------------------------------------------------------------------===//
@@ -302,6 +308,8 @@ def Index_MinSOp : IndexBinaryOp<"mins", [Commutative, Pure]> {
302308
%c = index.mins %a, %b
303309
```
304310
}];
311+
312+
let hasCanonicalizeMethod = 1;
305313
}
306314

307315
//===----------------------------------------------------------------------===//
@@ -322,6 +330,8 @@ def Index_MinUOp : IndexBinaryOp<"minu", [Commutative, Pure]> {
322330
%c = index.minu %a, %b
323331
```
324332
}];
333+
334+
let hasCanonicalizeMethod = 1;
325335
}
326336

327337
//===----------------------------------------------------------------------===//
@@ -404,6 +414,8 @@ def Index_AndOp : IndexBinaryOp<"and", [Commutative, Pure]> {
404414
%c = index.and %a, %b
405415
```
406416
}];
417+
418+
let hasCanonicalizeMethod = 1;
407419
}
408420

409421
//===----------------------------------------------------------------------===//
@@ -423,6 +435,8 @@ def Index_OrOp : IndexBinaryOp<"or", [Commutative, Pure]> {
423435
%c = index.or %a, %b
424436
```
425437
}];
438+
439+
let hasCanonicalizeMethod = 1;
426440
}
427441

428442
//===----------------------------------------------------------------------===//
@@ -442,6 +456,8 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
442456
%c = index.xor %a, %b
443457
```
444458
}];
459+
460+
let hasCanonicalizeMethod = 1;
445461
}
446462

447463
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,32 @@ static OpFoldResult foldBinaryOpChecked(
118118
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
119119
}
120120

121+
/// Helper for associative and commutative binary ops that can be transformed:
122+
/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
123+
/// where c1 and c2 are constants. It is expected that `tmp` will be folded.
124+
template <typename BinaryOp>
125+
LogicalResult
126+
canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op,
127+
PatternRewriter &rewriter) {
128+
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
129+
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
130+
131+
auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
132+
if (!lhsOp)
133+
return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
134+
135+
if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
136+
return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
137+
138+
Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
139+
lhsOp.getRhs());
140+
if (c.getDefiningOp<BinaryOp>())
141+
return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
142+
143+
rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
144+
return success();
145+
}
146+
121147
//===----------------------------------------------------------------------===//
122148
// AddOp
123149
//===----------------------------------------------------------------------===//
@@ -136,27 +162,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
136162

137163
return {};
138164
}
139-
/// Canonicalize
140-
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
141-
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
142-
IntegerAttr c1, c2;
143-
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
144-
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
145-
146-
auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
147-
if (!add)
148-
return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
149-
150-
if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
151-
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
152-
153-
auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
154-
c1.getInt() + c2.getInt());
155-
auto newAdd =
156-
rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);
157165

158-
rewriter.replaceOp(op, newAdd);
159-
return success();
166+
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
167+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
160168
}
161169

162170
//===----------------------------------------------------------------------===//
@@ -200,6 +208,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
200208
return {};
201209
}
202210

211+
LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
212+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
213+
}
214+
203215
//===----------------------------------------------------------------------===//
204216
// DivSOp
205217
//===----------------------------------------------------------------------===//
@@ -352,6 +364,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
352364
});
353365
}
354366

367+
LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
368+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
369+
}
370+
355371
//===----------------------------------------------------------------------===//
356372
// MaxUOp
357373
//===----------------------------------------------------------------------===//
@@ -363,6 +379,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
363379
});
364380
}
365381

382+
LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
383+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
384+
}
385+
366386
//===----------------------------------------------------------------------===//
367387
// MinSOp
368388
//===----------------------------------------------------------------------===//
@@ -374,6 +394,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
374394
});
375395
}
376396

397+
LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
398+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
399+
}
400+
377401
//===----------------------------------------------------------------------===//
378402
// MinUOp
379403
//===----------------------------------------------------------------------===//
@@ -385,6 +409,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
385409
});
386410
}
387411

412+
LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
413+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
414+
}
415+
388416
//===----------------------------------------------------------------------===//
389417
// ShlOp
390418
//===----------------------------------------------------------------------===//
@@ -442,6 +470,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
442470
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
443471
}
444472

473+
LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
474+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
475+
}
476+
445477
//===----------------------------------------------------------------------===//
446478
// OrOp
447479
//===----------------------------------------------------------------------===//
@@ -452,6 +484,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
452484
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
453485
}
454486

487+
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
488+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
489+
}
490+
455491
//===----------------------------------------------------------------------===//
456492
// XOrOp
457493
//===----------------------------------------------------------------------===//
@@ -462,6 +498,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
462498
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
463499
}
464500

501+
LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
502+
return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
503+
}
504+
465505
//===----------------------------------------------------------------------===//
466506
// CastSOp
467507
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)