Skip to content

Commit 30badf9

Browse files
mgehre-amdkuhar
andauthored
[MLIR][Arith] expand-ops: Support mini/maxi (#90575)
Expand `arith.minsi`, `arith.minui`, `arith.maxsi`, `arith.maxui` into `arith.cmpi` and `arith.select`. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent dbe3766 commit 30badf9

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,22 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
152152
}
153153
};
154154

155+
template <typename OpTy, arith::CmpIPredicate pred>
156+
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
157+
public:
158+
using OpRewritePattern<OpTy>::OpRewritePattern;
159+
160+
LogicalResult matchAndRewrite(OpTy op,
161+
PatternRewriter &rewriter) const final {
162+
Value lhs = op.getLhs();
163+
Value rhs = op.getRhs();
164+
165+
Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
166+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
167+
return success();
168+
}
169+
};
170+
155171
template <typename OpTy, arith::CmpFPredicate pred>
156172
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
157173
public:
@@ -335,6 +351,10 @@ struct ArithExpandOpsPass
335351
arith::CeilDivSIOp,
336352
arith::CeilDivUIOp,
337353
arith::FloorDivSIOp,
354+
arith::MaxSIOp,
355+
arith::MaxUIOp,
356+
arith::MinSIOp,
357+
arith::MinUIOp,
338358
arith::MaximumFOp,
339359
arith::MinimumFOp,
340360
arith::MaxNumFOp,
@@ -383,6 +403,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
383403
populateCeilFloorDivExpandOpsPatterns(patterns);
384404
// clang-format off
385405
patterns.add<
406+
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
407+
MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
408+
MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
409+
MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
386410
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
387411
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
388412
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
262262

263263
// CHECK-LABEL: @truncf_vector_f32
264264
// CHECK-NOT: arith.truncf
265+
266+
// -----
267+
268+
func.func @maxsi(%a: i32, %b: i32) -> i32 {
269+
%result = arith.maxsi %a, %b : i32
270+
return %result : i32
271+
}
272+
// CHECK-LABEL: func @maxsi
273+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
274+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
275+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
276+
// CHECK-NEXT: return %[[RESULT]] : i32
277+
278+
// -----
279+
280+
func.func @minsi(%a: i32, %b: i32) -> i32 {
281+
%result = arith.minsi %a, %b : i32
282+
return %result : i32
283+
}
284+
// CHECK-LABEL: func @minsi
285+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
286+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
287+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
288+
// CHECK-NEXT: return %[[RESULT]] : i32
289+
290+
// -----
291+
292+
func.func @maxui(%a: i32, %b: i32) -> i32 {
293+
%result = arith.maxui %a, %b : i32
294+
return %result : i32
295+
}
296+
// CHECK-LABEL: func @maxui
297+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
298+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
299+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
300+
// CHECK-NEXT: return %[[RESULT]] : i32
301+
302+
// -----
303+
304+
func.func @minui(%a: i32, %b: i32) -> i32 {
305+
%result = arith.minui %a, %b : i32
306+
return %result : i32
307+
}
308+
// CHECK-LABEL: func @minui
309+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
310+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
311+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
312+
// CHECK-NEXT: return %[[RESULT]] : i32

0 commit comments

Comments
 (0)