[mlir][math] lower rsqrt to sqrt + fdiv #91344

Merged 3 commits on May 13, 2024

Conversation

cferry-AMD
Contributor

This commit creates an expansion pattern to lower math.rsqrt(x) into fdiv(1, sqrt(x)).

@cferry-AMD
Contributor Author

This also turns out to be MathToLLVM's current behavior.

@llvmbot
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Corentin Ferry (cferry-AMD)

Changes

This commit creates an expansion pattern to lower math.rsqrt(x) into fdiv(1, sqrt(x)).


Full diff: https://github.com/llvm/llvm-project/pull/91344.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+1)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+22)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+42)
  • (modified) mlir/test/lib/Dialect/Math/TestExpandMath.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 24e6d9a8d98e..ba6977251564 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -42,6 +42,7 @@ void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateExpandFPowIPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
+void populateExpandRsqrtPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 5ccf3b6d72a2..05d32ad2bc3e 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -615,6 +615,24 @@ static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
   return success();
 }
 
+// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
+static LogicalResult convertRsqrtOp(math::RsqrtOp op,
+                                    PatternRewriter &rewriter) {
+
+  auto operand = op.getOperand();
+  auto operandTy = operand.getType();
+  auto eTy = getElementTypeOrSelf(operandTy);
+  if (!isa<FloatType>(eTy))
+    return failure();
+
+  Location loc = op->getLoc();
+  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
+  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
+  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
+                                             ValueRange{constOneFloat, sqrtOp});
+  return success();
+}
+
 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
   patterns.add(convertCtlzOp);
 }
@@ -678,3 +696,7 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
 void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundEvenOp);
 }
+
+void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
+  patterns.add(convertRsqrtOp);
+}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 3d94b55126d0..d25f4e571e6a 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -658,3 +658,45 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
 // CHECK:       return %[[SEL]] : f32
+
+// -----
+
+// CHECK-LABEL:   func.func @rsqrt
+// CHECK-SAME:     (%[[ARG:.*]]: f32)
+// CHECK-SAME:    -> f32
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : f32
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : f32
+// CHECK:         return %[[DIV]] : f32
+func.func @rsqrt(%float: f32) -> (f32)  {
+  %float_result = math.rsqrt %float : f32
+  return %float_result : f32
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @rsqrt_vec
+// CHECK-SAME:     (%[[ARG:.*]]: vector<5xf32>)
+// CHECK-SAME:    -> vector<5xf32>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<5xf32>
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : vector<5xf32>
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : vector<5xf32>
+// CHECK:         return %[[DIV]] : vector<5xf32>
+func.func @rsqrt_vec(%float: vector<5xf32>) -> (vector<5xf32>)  {
+  %float_result = math.rsqrt %float : vector<5xf32>
+  return %float_result : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @rsqrt_tns
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<5x8xf32>)
+// CHECK-SAME:    -> tensor<5x8xf32>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<5x8xf32>
+// CHECK-DAG:     %[[SQRT:.*]] = math.sqrt %[[ARG]] : tensor<5x8xf32>
+// CHECK-DAG:     %[[DIV:.*]] = arith.divf %[[CST]], %[[SQRT]] : tensor<5x8xf32>
+// CHECK:         return %[[DIV]] : tensor<5x8xf32>
+func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>)  {
+  %float_result = math.rsqrt %float : tensor<5x8xf32>
+  return %float_result : tensor<5x8xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index da48ccb6e5e0..69af2a08b97b 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -52,6 +52,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFPowIPattern(patterns);
   populateExpandRoundFPattern(patterns);
   populateExpandRoundEvenPattern(patterns);
+  populateExpandRsqrtPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 

@cferry-AMD
Contributor Author

cferry-AMD commented May 7, 2024

For review: @srcarroll @rsuderman @jinchen62 @pashu123, feel free to mention anyone else relevant!


Location loc = op->getLoc();
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
Contributor

nit: it is preferred to use the tablegen-designated getter functions for operands and results (otherwise you need to check that you can even index getOperand(0)), so you can use op.getOperand() instead (see https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Math/IR/MathOps.td#L44)

Contributor Author

Changed, thanks!

Contributor

@srcarroll left a comment

just one nit. otherwise looks good to me. thanks

@srcarroll
Contributor

out of curiosity, why choose div + sqrt over sqrt + div? does it make a difference in terms of accuracy?

Member

@pashu123 left a comment

auto operand = op.getOperand();
auto operandTy = operand.getType();
auto eTy = getElementTypeOrSelf(operandTy);
if (!isa<FloatType>(eTy))
Member

@pashu123 commented May 7, 2024

I think the expansion will work fine for f16, f32 and f64. You can enable them.

Contributor

doesn't FloatType encompass all of those?

Member

I see, yes. I only saw the tests. Could you edit the tests and add the F16 or F64 ones, too?

Contributor Author

Sure -- @srcarroll, you're right, F16 and F64 are indeed supported, let me add tests for them.
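
As a rough scalar analogue of why one pattern covers every float width, the sketch below uses a template parameter in place of the MLIR element type. The name `rsqrtExpanded` is hypothetical and only illustrates the shape of the expansion; it is not part of the PR or of MLIR.

```cpp
#include <cmath>

// Hypothetical scalar analogue of the expansion: rsqrt(x) -> 1 / sqrt(x).
// The template parameter stands in for the element type (f32, f64, ...);
// standard C++ has no portable f16 before C++23, so it is not shown here.
template <typename T>
T rsqrtExpanded(T x) {
  const T one = static_cast<T>(1.0);  // plays the role of the 1.0 constant
  return one / std::sqrt(x);          // sqrt first, then the division
}
```

Calling `rsqrtExpanded(4.0f)` and `rsqrtExpanded(4.0)` goes through the same code path for `float` and `double`, much like the pattern only checks that the element type is some `FloatType`.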

Location loc = op->getLoc();
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op->getOperand(0));
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
Member

Please include an integration test to verify UB. For example, the test should check the behaviour when the input value is 0.

Contributor Author

Added within test-expand-math-approx.mlir.
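
For readers who want to see what such an edge-case check pins down, here is a minimal host-side sketch in plain C++, assuming IEEE 754 semantics and no fast-math flags. The helper names are hypothetical, and this is not the MLIR integration test that was added to the PR.

```cpp
#include <cmath>
#include <limits>

// Hypothetical helper mirroring the expanded form: 1 / sqrt(x).
inline float rsqrtViaDiv(float x) { return 1.0f / std::sqrt(x); }

// Edge cases (assuming IEEE 754 semantics, no fast-math):
//   x = +0.0 -> sqrt gives +0.0, and 1 / +0.0 is +inf
//   x <  0   -> sqrt gives NaN, which propagates through the division
//   x = +inf -> sqrt gives +inf, and 1 / +inf is +0.0
inline bool zeroInputGivesPositiveInfinity() {
  const float r = rsqrtViaDiv(0.0f);
  return std::isinf(r) && !std::signbit(r);
}

inline bool negativeInputGivesNaN() {
  return std::isnan(rsqrtViaDiv(-1.0f));
}

inline bool infiniteInputGivesZero() {
  return rsqrtViaDiv(std::numeric_limits<float>::infinity()) == 0.0f;
}
```

Under those assumptions a zero input yields an infinity rather than a trap, which is the kind of result an integration test can check by printing the value.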

Member

@pashu123 left a comment

Some comments.

@cferry-AMD
Contributor Author

> out of curiosity, why choose div + sqrt over sqrt + div? does it make a difference in terms of accuracy?

I have actually not thought of that other way. Let me have a look... that's a good point!

@cferry-AMD
Contributor Author

> I have actually not thought of that other way. Let me have a look... that's a good point!

Shallow testing of 1/sqrt(x) against sqrt(1/x) for integer values of x did not reveal significant accuracy differences. However, the behavior for -0.0 differs when hand-writing the two versions into test-expand-math-approx.mlir:

  %1 = arith.constant -0.0 : f32
  %one = arith.constant 1.0 : f32
  
  // CHECK: -nan
  %2 = arith.divf %one, %1 : f32
  %3 = math.sqrt %2 : f32
  vector.print %3 : f32

  // CHECK: -inf
  %4 = math.sqrt %1 : f32
  %5 = arith.divf %one, %4 : f32
  vector.print %5 : f32
  • div(1, sqrt(x)) will return -inf,
  • sqrt(div(1, x)) will return -nan.

Given that MathToLLVM uses 1/sqrt(x), I'd go the same route to be consistent with it.
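
The same comparison can be reproduced on the host with plain C++, assuming IEEE 754 semantics and no fast-math flags. The function names below are hypothetical, and the accuracy scan is only a rough sketch of the kind of shallow test described above, not the author's actual experiment.

```cpp
#include <cmath>

// 1 / sqrt(x): the order used by this PR.
inline float divOfSqrt(float x) { return 1.0f / std::sqrt(x); }

// sqrt(1 / x): the alternative order raised in review.
inline float sqrtOfDiv(float x) { return std::sqrt(1.0f / x); }

// Largest relative gap between the two orderings for x = 1..n,
// measured against a double-precision reference value.
inline double maxRelativeGap(int n) {
  double worst = 0.0;
  for (int i = 1; i <= n; ++i) {
    const float x = static_cast<float>(i);
    const double ref = 1.0 / std::sqrt(static_cast<double>(i));
    const double gap =
        std::fabs(static_cast<double>(divOfSqrt(x)) -
                  static_cast<double>(sqrtOfDiv(x))) / ref;
    if (gap > worst) worst = gap;
  }
  return worst;
}
```

For `x = -0.0f`, `std::sqrt` returns `-0.0f`, so `divOfSqrt` produces negative infinity, while `sqrtOfDiv` takes the square root of negative infinity and produces NaN. For ordinary positive inputs the two orderings should agree to within a few units in the last place.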

@cferry-AMD cferry-AMD requested review from pashu123 and srcarroll May 8, 2024 08:28
@cferry-AMD cferry-AMD changed the title [math] lower rsqrt to sqrt + fdiv [mlir][math] lower rsqrt to sqrt + fdiv May 8, 2024

Location loc = op->getLoc();
auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op.getOperand());
Member

can pass the operand var directly here.

auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
auto sqrtOp = rewriter.create<math::SqrtOp>(loc, op.getOperand());
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, operandTy,
ValueRange{constOneFloat, sqrtOp});
Member

I think you don't need to pass operandTy. Can replace with rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp).

Member

@pashu123 left a comment

some nits.

@cferry-AMD
Contributor Author

Fixed!

Member

@pashu123 left a comment

LGTM!

@cferry-AMD
Contributor Author

I can't merge this yet and will request write access. Could one of you please merge it? @pashu123 @srcarroll
Thanks!

@mgehre-amd mgehre-amd merged commit 279a659 into llvm:main May 13, 2024
4 checks passed
@mgehre-amd mgehre-amd deleted the corentin.upstream_lower_rsqrt branch May 13, 2024 08:15
mgehre-amd added a commit that referenced this pull request May 13, 2024
Hotfix for "[mlir][math] lower rsqrt to sqrt + fdiv (#91344)"