-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][math] lower rsqrt to sqrt + fdiv #91344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This also turns out to be MathToLLVM's current behavior. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: Corentin Ferry (cferry-AMD) Changes: This commit creates an expansion pattern to lower math.rsqrt(x) into fdiv(1, sqrt(x)). Full diff: https://github.com/llvm/llvm-project/pull/91344.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 24e6d9a8d98e..ba6977251564 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -42,6 +42,7 @@ void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
+void populateExpandRsqrtPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 5ccf3b6d72a2..05d32ad2bc3e 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -615,6 +615,24 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
return success();
}
+// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
+static LogicalResult convertRsqrtOp(math::RsqrtOp op,
+ PatternRewriter &rewriter) {
+
+ auto operand = op.getOperand();
+ auto operandTy = operand.getType();
+ auto eTy = getElementTypeOrSelf(operandTy);
+ if (!isa<FloatType>(eTy))
+ return failure();
+
+ Location loc = op->getLoc();
+ auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
+ auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
+ rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
+ ValueRange{constOneFloat, sqrtOp});
+ return success();
+}
+
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
patterns.add(convertCtlzOp);
}
@@ -678,3 +696,7 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundEvenOp);
}
+
+void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
+ patterns.add(convertRsqrtOp);
+}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 3d94b55126d0..d25f4e571e6a 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -658,3 +658,45 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
// CHECK: return %[[SEL]] : f32
+
+// -----
+
+// CHECK-LABEL: func.func @rsqrt
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+// CHECK-SAME: -> f32
+// CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
+// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
+// CHECK: return %[[DIV]] : f32
+func.func @rsqrt(%float: f32) -> (f32) {
+ %float_result = math.rsqrt %float : f32
+ return %float_result : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rsqrt_vec
+// CHECK-SAME: (%[[ARG:.*]]: vector<5xf32>)
+// CHECK-SAME: -> vector<5xf32>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<5xf32>
+// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : vector<5xf32>
+// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : vector<5xf32>
+// CHECK: return %[[DIV]] : vector<5xf32>
+func.func @rsqrt_vec(%float: vector<5xf32>) -> (vector<5xf32>) {
+ %float_result = math.rsqrt %float : vector<5xf32>
+ return %float_result : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rsqrt_tns
+// CHECK-SAME: (%[[ARG:.*]]: tensor<5x8xf32>)
+// CHECK-SAME: -> tensor<5x8xf32>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<5x8xf32>
+// CHECK-DAG: %[[SQRT:.*]] = math.sqrt %[[ARG]] : tensor<5x8xf32>
+// CHECK-DAG: %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : tensor<5x8xf32>
+// CHECK: return %[[DIV]] : tensor<5x8xf32>
+func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) {
+ %float_result = math.rsqrt %float : tensor<5x8xf32>
+ return %float_result : tensor<5x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index da48ccb6e5e0..69af2a08b97b 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -52,6 +52,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFPowIPattern(patterns);
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
+ populateExpandRsqrtPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
|
For review: @srcarroll @rsuderman @jinchen62 @pashu123, feel free to mention anyone else relevant! |
|
||
Location loc = op->getLoc(); | ||
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); | ||
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it is preferred to use the tablegen-designated getter functions for operands and results (otherwise you need to check that you can even index `getOperand(0)`), so you can do `op.getOperand()` instead (see https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Math/IR/MathOps.td#L44)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just one nit. otherwise looks good to me. thanks
out of curiosity, why choose |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add integration tests here: https://github.com/llvm/llvm-project/blob/main/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
auto operand = op.getOperand(); | ||
auto operandTy = operand.getType(); | ||
auto eTy = getElementTypeOrSelf(operandTy); | ||
if (!isa<FloatType>(eTy)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the expansion will work fine for f16, f32 and f64. You can enable them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't `FloatType` encompass all of those?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, yes. I only saw the tests. Could you edit the tests and add the F16 or F64 ones, too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure -- @srcarroll, you're right, F16 and F64 are indeed supported, let me add tests for them.
Location loc = op->getLoc(); | ||
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); | ||
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0)); | ||
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please include an integration test to verify UB. For example, the test should check the behaviour when the input value is 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added within `test-expand-math-approx.mlir`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments.
I have actually not thought of that other way. Let me have a look... that's a good point! |
Shallow testing on 1/sqrt(x) for integer values of x did not point out significant accuracy differences. However, the behavior with respect to `-0.0` differs between sqrt(1/x) and 1/sqrt(x):
%1 = arith.constant -0.0 : f32
%one = arith.constant 1.0 : f32
// CHECK: -nan
%2 = arith.divf %one, %1 : f32
%3 = math.sqrt %2 : f32
vector.print %3 : f32
// CHECK: -inf
%4 = math.sqrt %1 : f32
%5 = arith.divf %one, %4 : f32
vector.print %5 : f32
Given that MathToLLVM uses 1/sqrt(x), I'd go the same route to be consistent with it. |
|
||
Location loc = op->getLoc(); | ||
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); | ||
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op.getOperand()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can pass the `operand` variable directly here.
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter); | ||
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op.getOperand()); | ||
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy, | ||
ValueRange{constOneFloat, sqrtOp}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you don't need to pass `operandTy`. You can replace this with `rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp)`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some nits.
Fixed! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
I can't merge this (yet), will request write access. Can one of you please do? @pashu123 @srcarroll |
Hotfix for "[mlir][math] lower rsqrt to sqrt + fdiv (#91344)"
This commit creates an expansion pattern to lower math.rsqrt(x) into fdiv(1, sqrt(x)).