@@ -118,6 +118,32 @@ static OpFoldResult foldBinaryOpChecked(
118
118
return IntegerAttr::get (IndexType::get (lhs.getContext ()), *result64);
119
119
}
120
120
121
+ // / Helper for associative and commutative binary ops that can be transformed:
122
+ // / `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
123
+ // / where c1 and c2 are constants. It is expected that `tmp` will be folded.
124
+ template <typename BinaryOp>
125
+ LogicalResult
126
+ canonicalizeAssociativeCommutativeBinaryOp (BinaryOp op,
127
+ PatternRewriter &rewriter) {
128
+ if (!mlir::matchPattern (op.getRhs (), mlir::m_Constant ()))
129
+ return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
130
+
131
+ auto lhsOp = op.getLhs ().template getDefiningOp <BinaryOp>();
132
+ if (!lhsOp)
133
+ return rewriter.notifyMatchFailure (op.getLoc (), " LHS is not the same BinaryOp" );
134
+
135
+ if (!mlir::matchPattern (lhsOp.getRhs (), mlir::m_Constant ()))
136
+ return rewriter.notifyMatchFailure (op.getLoc (), " RHS of LHS op is not a constant" );
137
+
138
+ Value c = rewriter.createOrFold <BinaryOp>(op->getLoc (), op.getRhs (),
139
+ lhsOp.getRhs ());
140
+ if (c.getDefiningOp <BinaryOp>())
141
+ return rewriter.notifyMatchFailure (op.getLoc (), " new BinaryOp was not folded" );
142
+
143
+ rewriter.replaceOpWithNewOp <BinaryOp>(op, lhsOp.getLhs (), c);
144
+ return success ();
145
+ }
146
+
121
147
// ===----------------------------------------------------------------------===//
122
148
// AddOp
123
149
// ===----------------------------------------------------------------------===//
@@ -136,27 +162,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
136
162
137
163
return {};
138
164
}
139
- // / Canonicalize
140
- // / ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
141
- LogicalResult AddOp::canonicalize (AddOp op, PatternRewriter &rewriter) {
142
- IntegerAttr c1, c2;
143
- if (!mlir::matchPattern (op.getRhs (), mlir::m_Constant (&c1)))
144
- return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
145
-
146
- auto add = op.getLhs ().getDefiningOp <mlir::index::AddOp>();
147
- if (!add)
148
- return rewriter.notifyMatchFailure (op.getLoc (), " LHS is not a add" );
149
-
150
- if (!mlir::matchPattern (add.getRhs (), mlir::m_Constant (&c2)))
151
- return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
152
-
153
- auto c = rewriter.create <mlir::index::ConstantOp>(op->getLoc (),
154
- c1.getInt () + c2.getInt ());
155
- auto newAdd =
156
- rewriter.create <mlir::index::AddOp>(op->getLoc (), add.getLhs (), c);
157
165
158
- rewriter. replaceOp ( op, newAdd);
159
- return success ( );
166
+ LogicalResult AddOp::canonicalize (AddOp op, PatternRewriter &rewriter) {
167
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter );
160
168
}
161
169
162
170
// ===----------------------------------------------------------------------===//
@@ -200,6 +208,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
200
208
return {};
201
209
}
202
210
211
+ LogicalResult MulOp::canonicalize (MulOp op, PatternRewriter &rewriter) {
212
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
213
+ }
214
+
203
215
// ===----------------------------------------------------------------------===//
204
216
// DivSOp
205
217
// ===----------------------------------------------------------------------===//
@@ -352,6 +364,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
352
364
});
353
365
}
354
366
367
+ LogicalResult MaxSOp::canonicalize (MaxSOp op, PatternRewriter &rewriter) {
368
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
369
+ }
370
+
355
371
// ===----------------------------------------------------------------------===//
356
372
// MaxUOp
357
373
// ===----------------------------------------------------------------------===//
@@ -363,6 +379,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
363
379
});
364
380
}
365
381
382
+ LogicalResult MaxUOp::canonicalize (MaxUOp op, PatternRewriter &rewriter) {
383
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
384
+ }
385
+
366
386
// ===----------------------------------------------------------------------===//
367
387
// MinSOp
368
388
// ===----------------------------------------------------------------------===//
@@ -374,6 +394,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
374
394
});
375
395
}
376
396
397
+ LogicalResult MinSOp::canonicalize (MinSOp op, PatternRewriter &rewriter) {
398
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
399
+ }
400
+
377
401
// ===----------------------------------------------------------------------===//
378
402
// MinUOp
379
403
// ===----------------------------------------------------------------------===//
@@ -385,6 +409,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
385
409
});
386
410
}
387
411
412
+ LogicalResult MinUOp::canonicalize (MinUOp op, PatternRewriter &rewriter) {
413
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
414
+ }
415
+
388
416
// ===----------------------------------------------------------------------===//
389
417
// ShlOp
390
418
// ===----------------------------------------------------------------------===//
@@ -442,6 +470,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
442
470
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
443
471
}
444
472
473
+ LogicalResult AndOp::canonicalize (AndOp op, PatternRewriter &rewriter) {
474
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
475
+ }
476
+
445
477
// ===----------------------------------------------------------------------===//
446
478
// OrOp
447
479
// ===----------------------------------------------------------------------===//
@@ -452,6 +484,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
452
484
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
453
485
}
454
486
487
+ LogicalResult OrOp::canonicalize (OrOp op, PatternRewriter &rewriter) {
488
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
489
+ }
490
+
455
491
// ===----------------------------------------------------------------------===//
456
492
// XOrOp
457
493
// ===----------------------------------------------------------------------===//
@@ -462,6 +498,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
462
498
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
463
499
}
464
500
501
+ LogicalResult XOrOp::canonicalize (XOrOp op, PatternRewriter &rewriter) {
502
+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
503
+ }
504
+
465
505
// ===----------------------------------------------------------------------===//
466
506
// CastSOp
467
507
// ===----------------------------------------------------------------------===//
0 commit comments