Skip to content

Commit ac54176

Browse files
authored
expand-ops: minsi/maxsi (and unsigned) (#171)
* expand-ops: minsi/maxsi (and unsigned)
1 parent 86e0d41 commit ac54176

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,23 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
160160
}
161161
};
162162

163+
template <typename OpTy, arith::CmpIPredicate pred>
164+
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
165+
public:
166+
using OpRewritePattern<OpTy>::OpRewritePattern;
167+
168+
LogicalResult matchAndRewrite(OpTy op,
169+
PatternRewriter &rewriter) const final {
170+
Value lhs = op.getLhs();
171+
Value rhs = op.getRhs();
172+
173+
Location loc = op.getLoc();
174+
Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
175+
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
176+
return success();
177+
}
178+
};
179+
163180
template <typename OpTy, arith::CmpFPredicate pred>
164181
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
165182
public:
@@ -344,6 +361,10 @@ struct ArithExpandOpsPass
344361
arith::CeilDivSIOp,
345362
arith::CeilDivUIOp,
346363
arith::FloorDivSIOp,
364+
arith::MaxSIOp,
365+
arith::MaxUIOp,
366+
arith::MinSIOp,
367+
arith::MinUIOp,
347368
arith::MaximumFOp,
348369
arith::MinimumFOp,
349370
arith::MaxNumFOp,
@@ -392,6 +413,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
392413
populateCeilFloorDivExpandOpsPatterns(patterns);
393414
// clang-format off
394415
patterns.add<
416+
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
417+
MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
418+
MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
419+
MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
395420
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
396421
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
397422
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
295295

296296
// CHECK-LABEL: @truncf_vector_f32
297297
// CHECK-NOT: arith.truncf
298+
299+
// -----
300+
301+
func.func @maxsi(%a: i32, %b: i32) -> i32 {
302+
%result = arith.maxsi %a, %b : i32
303+
return %result : i32
304+
}
305+
// CHECK-LABEL: func @maxsi
306+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
307+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
308+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
309+
// CHECK-NEXT: return %[[RESULT]] : i32
310+
311+
// -----
312+
313+
func.func @minsi(%a: i32, %b: i32) -> i32 {
314+
%result = arith.minsi %a, %b : i32
315+
return %result : i32
316+
}
317+
// CHECK-LABEL: func @minsi
318+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
319+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
320+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
321+
// CHECK-NEXT: return %[[RESULT]] : i32
322+
323+
// -----
324+
325+
func.func @maxui(%a: i32, %b: i32) -> i32 {
326+
%result = arith.maxui %a, %b : i32
327+
return %result : i32
328+
}
329+
// CHECK-LABEL: func @maxui
330+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
331+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
332+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
333+
// CHECK-NEXT: return %[[RESULT]] : i32
334+
335+
// -----
336+
337+
func.func @minui(%a: i32, %b: i32) -> i32 {
338+
%result = arith.minui %a, %b : i32
339+
return %result : i32
340+
}
341+
// CHECK-LABEL: func @minui
342+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
343+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
344+
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
345+
// CHECK-NEXT: return %[[RESULT]] : i32

0 commit comments

Comments
 (0)