Skip to content

[MLIR] Lower math.powf(x, 3.0) to x * x * x. #127256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged

[MLIR] Lower math.powf(x, 3.0) to x * x * x. #127256

merged 1 commit into from
Feb 14, 2025

Conversation

bjacob
Copy link
Contributor

@bjacob bjacob commented Feb 14, 2025

math.powf(x, y) never really supported negative values of x, but that was unclear (happened to work for some values of y) until #126338 was merged yesterday and lowered it to the usual exp(y * log(x)) outside of a few special exponent values, such as y == 2.0 lowering to x * x.

It turns out that code in the wild has been relying on math.powf(x, y) with negative x for some integral values of y for which a lowering to muls was intended: iree-org/iree#19996

This PR adds such a lowering for y == 3.0. It "fixes" such cases, and it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using powf with negative x, use math.fpowi for that.

Signed-off-by: Benoit Jacob <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2025

@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

Changes

math.powf(x, y) never really supported negative values of x, but that was unclear (happened to work for some values of y) until #126338 was merged yesterday and lowered it to the usual exp(y * log(x)) outside of a few special exponent values, such as y == 2.0lowering tox * x`.

It turns out that code in the wild has been relying on math.powf(x, y) with negative x for some integral values of y for which a lowering to muls was intended: iree-org/iree#19996

This PR adds such a lowering for y == 3.0. It "fixes" such cases, and it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using powf with negative x, use math.fpowi for that.


Full diff: https://github.com/llvm/llvm-project/pull/127256.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+9-4)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+11)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   auto &sem =
       cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
   APFloat valueB(sem);
+  auto mulf = [&](Value x, Value y) -> Value {
+    return b.create<arith::MulFOp>(x, y);
+  };
   if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
     if (valueB.isZero()) {
       // a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
     }
     if (valueB.isExactlyValue(2.0)) {
       // a^2 -> a * a
-      Value mul = b.create<arith::MulFOp>(operandA, operandA);
-      rewriter.replaceOp(op, mul);
+      rewriter.replaceOp(op, mulf(operandA, operandA));
       return success();
     }
     if (valueB.isExactlyValue(-2.0)) {
       // a^(-2) -> 1 / (a * a)
-      Value mul = b.create<arith::MulFOp>(operandA, operandA);
       Value one =
           createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-      Value div = b.create<arith::DivFOp>(one, mul);
+      Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
       rewriter.replaceOp(op, div);
       return success();
     }
+    if (valueB.isExactlyValue(3.0)) {
+      rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+      return success();
+    }
   }
 
   Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
   return %ret : f64
 }
 
+// CHECK-LABEL:   func @powf_func_three
+// CHECK-SAME:    (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+  // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+  // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+  // CHECK: return %[[MUL2]] : f64
+  %b = arith.constant 3.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @roundeven64

@llvmbot
Copy link
Member

llvmbot commented Feb 14, 2025

@llvm/pr-subscribers-mlir-math

Author: Benoit Jacob (bjacob)

Changes

math.powf(x, y) never really supported negative values of x, but that was unclear (happened to work for some values of y) until #126338 was merged yesterday and lowered it to the usual exp(y * log(x)) outside of a few special exponent values, such as y == 2.0lowering tox * x`.

It turns out that code in the wild has been relying on math.powf(x, y) with negative x for some integral values of y for which a lowering to muls was intended: iree-org/iree#19996

This PR adds such a lowering for y == 3.0. It "fixes" such cases, and it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using powf with negative x, use math.fpowi for that.


Full diff: https://github.com/llvm/llvm-project/pull/127256.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+9-4)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+11)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   auto &sem =
       cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
   APFloat valueB(sem);
+  auto mulf = [&](Value x, Value y) -> Value {
+    return b.create<arith::MulFOp>(x, y);
+  };
   if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
     if (valueB.isZero()) {
       // a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
     }
     if (valueB.isExactlyValue(2.0)) {
       // a^2 -> a * a
-      Value mul = b.create<arith::MulFOp>(operandA, operandA);
-      rewriter.replaceOp(op, mul);
+      rewriter.replaceOp(op, mulf(operandA, operandA));
       return success();
     }
     if (valueB.isExactlyValue(-2.0)) {
       // a^(-2) -> 1 / (a * a)
-      Value mul = b.create<arith::MulFOp>(operandA, operandA);
       Value one =
           createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
-      Value div = b.create<arith::DivFOp>(one, mul);
+      Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
       rewriter.replaceOp(op, div);
       return success();
     }
+    if (valueB.isExactlyValue(3.0)) {
+      rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+      return success();
+    }
   }
 
   Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
   return %ret : f64
 }
 
+// CHECK-LABEL:   func @powf_func_three
+// CHECK-SAME:    (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+  // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+  // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+  // CHECK: return %[[MUL2]] : f64
+  %b = arith.constant 3.0 : f64
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @roundeven64

@bjacob bjacob merged commit 0301bf9 into llvm:main Feb 14, 2025
11 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
`math.powf(x, y)` never really supported negative values of `x`, but
that was unclear (happened to work for some values of `y`) until
llvm#126338 was merged yesterday
and lowered it to the usual `exp(y * log(x))` outside of a few special
exponent values, such as y == 2.0` lowering to `x * x`.

It turns out that code in the wild has been relying on `math.powf(x, y)`
with negative `x` for some integral values of `y` for which a lowering
to muls was intended: iree-org/iree#19996

This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and
it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using `powf` with
negative `x`, use `math.fpowi` for that.

Signed-off-by: Benoit Jacob <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants