-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Arith] expand-ops: Support mini/maxi #90575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Matthias Gehre (mgehre-amd) ChangesFull diff: https://github.com/llvm/llvm-project/pull/90575.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index dd04a599655894..676747ff01d09d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -152,6 +152,23 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
}
};
+template <typename OpTy, arith::CmpIPredicate pred>
+struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const final {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Location loc = op.getLoc();
+ Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
+ return success();
+ }
+};
+
template <typename OpTy, arith::CmpFPredicate pred>
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
public:
@@ -335,6 +352,10 @@ struct ArithExpandOpsPass
arith::CeilDivSIOp,
arith::CeilDivUIOp,
arith::FloorDivSIOp,
+ arith::MaxSIOp,
+ arith::MaxUIOp,
+ arith::MinSIOp,
+ arith::MinUIOp,
arith::MaximumFOp,
arith::MinimumFOp,
arith::MaxNumFOp,
@@ -383,6 +404,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
patterns.add<
+ MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
+ MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
+ MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
+ MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 6bed93e4c969db..174eb468cc0041 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -262,3 +262,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf
+
+// -----
+
+func.func @maxsi(%a: i32, %b: i32) -> i32 {
+ %result = arith.maxsi %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @maxsi
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+func.func @minsi(%a: i32, %b: i32) -> i32 {
+ %result = arith.minsi %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @minsi
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+func.func @maxui(%a: i32, %b: i32) -> i32 {
+ %result = arith.maxui %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @maxui
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
+
+// -----
+
+func.func @minui(%a: i32, %b: i32) -> i32 {
+ %result = arith.minui %a, %b : i32
+ return %result : i32
+}
+// CHECK-LABEL: func @minui
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RESULT]] : i32
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation looks good, but can you explain why we need to expand them in the first place? I'm not familiar with the LLVM lowering; it seems like LLVM has matching intrinsics: https://llvm.org/docs/LangRef.html#llvm-smin-intrinsic
Co-authored-by: Jakub Kuderski <[email protected]>
Sure: We don't target LLVM and our target dialect (emitc) has no representation for them. |
Expand
arith.minsi
,arith.minui
,arith.maxsi
,arith.maxui
intoarith.cmpi
andarith.select
.