Skip to content

[mlir][math] powf(a, b) drop support when a < 0 #126338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 13, 2025

Conversation

ita9naiwa
Copy link
Contributor

@ita9naiwa ita9naiwa commented Feb 8, 2025

Related: #124402

  • change inefficient implementation of powf(a, b) to handle a < 0 case
    • thus drop a < 0 case support

However, some special cases are being used such as:

  • a < 0 and b = 0, b = 1 or b % 2 == 0
  • convert those special cases into simpler ops.

The current implementation of `convertPowfOp` requires a calculation of
`a * a` but, max\<fp16\> ~= 65,504, and if `a` is about 16, it will
overflow so get INF in fp8 or fp16 easily.


Remove support when `a < 0`. Overhead of handling negative value of `a`
is large and easy to overflow;

- related issue in iree:
iree-org/iree#15936
@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2025

@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Hyunsung Lee (ita9naiwa)

Changes

Related: #124402

  • change inefficient implementation of powf(a, b) to handle a &lt; 0 case

    • thus drop a &lt; 0 case support
  • some special cases are being used such as:

    • a &lt; 0 &amp;&amp; b % 2 == 0
  • convert those special cases into simpler ops.


Full diff: https://github.com/llvm/llvm-project/pull/126338.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+97-18)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+20-51)
  • (modified) mlir/test/mlir-runner/test-expand-math-approx.mlir (-5)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f67..235ea38dd87d1c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -17,8 +17,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cmath>
 
 using namespace mlir;
 
@@ -311,7 +316,92 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Convert Powf(float a, float b) for some special cases
+// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+                                          PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  auto baseType = operandB.getType();
+
+  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+                  .getFloatSemantics();
+
+  auto valueB = APFloat(sem);
+  if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
+    // Not a constant, return failure
+    return failure();
+  }
+  float floatValueB = valueB.convertToFloat();
+
+  if (floatValueB == 1.0f) {
+    // a^1 -> a
+    rewriter.replaceOp(op, operandA);
+    return success();
+  }
+
+  if (floatValueB == 0.0) {
+    // a^0 -> 1
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    rewriter.replaceOp(op, one);
+    return success();
+  }
+
+  if (floatValueB == 0.5f) {
+    // a^(1/2) -> sqrt(a)
+    Value sqrt = b.create<math::SqrtOp>(operandA);
+    rewriter.replaceOp(op, sqrt);
+    return success();
+  }
+
+  if (floatValueB == -0.5f) {
+    // a^(-1/2) -> 1 / sqrt(a)
+    Value rsqrt = b.create<math::RsqrtOp>(operandA);
+    rewriter.replaceOp(op, rsqrt);
+    return success();
+  }
+
+  if (floatValueB == -1.0f) {
+    // a^(-1) -> 1 / a
+    Value one =
+        createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+    Value div = b.create<arith::DivFOp>(one, operandA);
+    rewriter.replaceOp(op, div);
+    return success();
+  }
+
+  // Check if the power is an integer
+  if (floatValueB != std::floor(floatValueB)) {
+    // We don't handle non-integer powers here, return failure
+    return failure();
+  }
+
+  auto sign = std::signbit(floatValueB) ? -1 : 1;
+  auto absIntValueB = std::abs(static_cast<int>(floatValueB));
+
+  auto cstOne =
+      createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+  auto base = operandA;
+  if (sign == -1) {
+    base = b.create<arith::DivFOp>(cstOne, base);
+  }
+  auto current = base;
+  auto result = cstOne;
+  while (absIntValueB > 0) {
+    if (absIntValueB & 1) {
+      result = b.create<arith::MulFOp>(result, current);
+    }
+    current = b.create<arith::MulFOp>(current, current);
+    absIntValueB >>= 1;
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Restricting a >= 0
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -319,21 +409,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
-  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
-  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
-  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
 
-  Value logA = b.create<math::LogOp>(opType, opASquared);
-  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+  Value logA = b.create<math::LogOp>(opType, operandA);
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
-  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value negCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value oddPower =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -341,10 +420,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
   Value zeroCheck =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
-                                        expResult);
-  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
-  rewriter.replaceOp(op, res);
+  Value finalResult =
+      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
+  rewriter.replaceOp(op, finalResult);
   return success();
 }
 
@@ -660,6 +738,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
 }
 
 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertSpecialPowfOp);
   patterns.add(convertPowfOp);
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84ca..5b443e9e8d4e78e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 {
 
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
   // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
-  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
-  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
-  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
-  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
-  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
-  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
-  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
-  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
-  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
-  // CHECK: return [[SEL1]]
+  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
+  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
+  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
+  // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+  // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
+  // CHECK: return [[SEL]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:      return %[[SEL1]] : tensor<8xf32>
-
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL]]
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
+// CHECK:        %[[LOGA:.*]] = math.log %[[ARG0]] : f32
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
 // CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
-// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK:       return %[[SEL1]] : f32
+// CHECK:        %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK:        %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
+// CHECK:       return %[[SEL]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index 106b48a2daea2e3..d1916c28878b97a 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,11 +202,6 @@ func.func @powf() {
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
-  // CHECK-NEXT: -27
-  %b   = arith.constant -3.0 : f64
-  %b_p = arith.constant 3.0 : f64
-  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
-
   // CHECK-NEXT: 2.343
   %c   = arith.constant 2.343 : f64
   %c_p = arith.constant 1.000 : f64

@ita9naiwa
Copy link
Contributor Author

ita9naiwa commented Feb 8, 2025

I want this looked at before I start changing lit test cases

@ScottTodd ScottTodd requested a review from bjacob February 8, 2025 04:21
}

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
// Restricting a >= 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this actually checked? This seems to be expanding under this assumption, but always does it. Is this a new assumption on the op that should be documented?

Copy link
Contributor Author

@ita9naiwa ita9naiwa Feb 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry!. I forgot to describe PR in more detail.

  1. I believe it should be documented, in general, powf(a, b) where a < 0 generally yields NaN and we (as far as I know) aren't able to check it runtime.

  1. This transform should be applied to some small number of 'b' (e.g., when 'abs(b) < 16')
  while (absIntValueB > 0) {
    if (absIntValueB & 1) {
      result = b.create<arith::MulFOp>(result, current);
    }
    current = b.create<arith::MulFOp>(current, current);
    absIntValueB >>= 1;
  }
  rewriter.replaceOp(op, result);
  return success();

The heuristic number for b is not determined yet. This last case can be dropped if it's not necessary

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One problem is that there are some special use-cases where var a < 0 but const b == some multiple of 2 cc @hanhanW

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just drop the comment // Restricting a >= 0 here.

Mathematically, the power operation a^b, is well-defined in two separate (though overlapping) cases:

  1. When a > 0. In that case, a^b is defined as exp(b * ln(a)).
  2. When b is an integer. In that case, a^b is defined as a * ... * a, (b times), or the reciprocal of that if b is negative.

These two definitions agree in the intersection of these two cases.

Because "power" has inherently that two-mode definition, the MLIR op powf should have been specified from the start to implement one of these two modes only. Obviously it should have been a > 0.

I believe that it is still time to clarify that. We have observed recently that some rewrite patterns for powf ops have been broken outside of the case a > 0, suggesting that no one was relying on that.

But that discussion doesn't need to be conflated into this PR, because this PR implements rewrites that are either agnostic as to which case we are in (e.g. the case of pow(a, 2.0)) or that are explicitly not applying to the other case anyway (e.g. the case of pow(a, 0.5)).

}

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
// Restricting a >= 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just drop the comment // Restricting a >= 0 here.

Mathematically, the power operation a^b, is well-defined in two separate (though overlapping) cases:

  1. When a > 0. In that case, a^b is defined as exp(b * ln(a)).
  2. When b is an integer. In that case, a^b is defined as a * ... * a, (b times), or the reciprocal of that if b is negative.

These two definitions agree in the intersection of these two cases.

Because "power" has inherently that two-mode definition, the MLIR op powf should have been specified from the start to implement one of these two modes only. Obviously it should have been a > 0.

I believe that it is still time to clarify that. We have observed recently that some rewrite patterns for powf ops have been broken outside of the case a > 0, suggesting that no one was relying on that.

But that discussion doesn't need to be conflated into this PR, because this PR implements rewrites that are either agnostic as to which case we are in (e.g. the case of pow(a, 2.0)) or that are explicitly not applying to the other case anyway (e.g. the case of pow(a, 0.5)).

Comment on lines 375 to 376
// Check if the power is an integer
if (floatValueB != std::floor(floatValueB)) {
Copy link
Contributor

@bjacob bjacob Feb 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't try to handle arbitrary integer x. Just handle the special value 2.0, and maybe also 3.0 and 4.0 if you want, but that's it. If someone really needs a larger integral exponent to be match, we can always expand this pattern later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, let's handle few cases for now and document it in the function comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept only |b|=2.0 case and removed other integer cases.

@@ -660,6 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
}

void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertSpecialPowfOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question for MLIR experts: Here, we want the convertSpecialPowfOp to have precedence over the convertPowfOp pattern. Is that ensured by it being added first here? If not, do we need to merge these two patterns to ensure ordering?

Copy link
Contributor Author

@ita9naiwa ita9naiwa Feb 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

patterns.add(convertSpecialPowfOp, /*benefit=*/ 2);
This would explicitly give the order we want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very sure since I didn't check the code, but adding patterns in this order make convertSpecialPowfOp run first.

patterns.add(convertSpecialPowfOp);
patterns.add(convertPowfOp);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@hanhanW hanhanW Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very sure since I didn't check the code, but adding patterns in this order make convertSpecialPowfOp run first.

This is correct, but I think it is not documented. IMO, we prefer using benefit to prioritize the patterns.

https://mlir.llvm.org/docs/PatternRewriter/#benefit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to give explicit benefit=2!

Copy link

github-actions bot commented Feb 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ita9naiwa ita9naiwa requested review from bjacob and hanhanW February 12, 2025 02:33
auto opType = operandA.getType();
auto baseType = operandB.getType();

auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, since math.powf requires float arguments (I just checked MathOps.td), the cast really shouldn't ever fail, so I think you can simply use cast instead of dyn_cast. You weren't checking for a null return value from dyn_cast anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Comment on lines 321 to 322
auto opType = operandA.getType();
auto baseType = operandB.getType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rename opType to typeA, and baseType to typeB ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

// Not a constant, return failure
return failure();
}
float floatValueB = valueB.convertToFloat();
Copy link
Contributor

@bjacob bjacob Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid converting to C++ float, as the actual type could have higher precision, so this conversion would be rounding to lesser precision and could end up enabling an incorrect rewrite, e.g. if the type is f64, then the rewrite from powf(a, 1.0 + 1.0e-10) into a would be incorrect as, in f64, 1.0 + 1.0e-10 != 1.0, but the rounded floatValueB is exactly 1.0.

Can you keep the whole logic in this function in APFloat?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, but this is might be bit verbose so could you check it please?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hanhanW , i am not familiar myself with APFloat, can you review that aspect?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very familiar with it either. After skimming through the doc, I think we can use isExactlyValue method. I think we do not have precision issue for 0, +-1, +-0.5, +-2 numbers.

/// We don't rely on operator== working on double values, as
/// it returns true for things that are clearly not equal, like -0.0 and 0.0.
/// As such, this method can be used to do an exact bit-for-bit comparison of
/// two floating point values.
///
/// We leave the version with the double argument here because it's just so
/// convenient to write "2.0" and the like. Without this function we'd
/// have to duplicate its logic everywhere it's called.
bool isExactlyValue(double V) const {
bool ignored;
APFloat Tmp(V);
Tmp.convert(getSemantics(), APFloat::rmNearestTiesToEven, &ignored);
return bitwiseIsEqual(Tmp);
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

@ita9naiwa
Copy link
Contributor Author

I also think that it's might be great to merge convertSpecialPowfOp and convertPowfOp and keep it in a one function...?

@bjacob
Copy link
Contributor

bjacob commented Feb 12, 2025

I also think that it's might be great to merge convertSpecialPowfOp and convertPowfOp and keep it in a one function...?

@ita9naiwa , yes, good idea. That would be simpler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least we need all the special cases are tested in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my apology for that tihs PR got verbose, I made appropriate tests and runs well!!

@ita9naiwa
Copy link
Contributor Author

I added appropriate tests, and I think it's now ready to land!

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one question about getFloatSemantics.

Comment on lines +325 to +327
auto &sem =
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not pretty sure, but do we really need to get the semantics? I feel that the matcher already does the work for you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/ita9naiwa/llvm-project/blob/d5ee522a217359264d35516a5b506149b5d1ab68/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp#L286-L288
other parts the same file explicitly use sem so I followed. It would work without it, should we remove?

Comment on lines 258 to 259
// CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64
// CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand how this test succeeds: The C++ code in the rewrite is correctly generating a math.rsqrt, which is indeed better than a math.sqrt + arith.divf. So how is this test, which requires the latter, passing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

powf decomposed into rsqrt, then rsqrt decompose further into these ops, I don't understand where it occurs and only capture the first transform

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, I see. Thanks for the explanation.

Comment on lines 206 to 209
// CHECK: [[LOGA:%.+]] = math.log [[ARG0]] : f64
// CHECK: [[MUL:%.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64
// CHECK: [[EXP:%.+]] = math.exp [[MUL]] : f64
// CHECK: return [[EXP]] : f64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's follow the %[[XXX:.+]] style because it is more common in MLIR codebase. One of the benefits is that we do not need to escape [ in the [%[[XXX]]] case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed other tests which are using %[[VAR:%.+]], I would fix.

// CHECK-LABEL:     func @ceilf_func
// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
func.func @ceilf_func(%a: f64) -> f64 {
  // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
  // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
  // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
  // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
  // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
  // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
  // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
  // CHECK-NEXT:   return [[ADDF]]
  %ret = math.ceil %a : f64
  return %ret : f64
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed, thanks now I see how file-check test works.

@bjacob bjacob self-requested a review February 13, 2025 02:33
@hanhanW hanhanW merged commit de09986 into llvm:main Feb 13, 2025
8 checks passed
@ita9naiwa ita9naiwa deleted the ita9naiwa/powf branch February 13, 2025 23:07
bjacob added a commit that referenced this pull request Feb 14, 2025
`math.powf(x, y)` never really supported negative values of `x`, but
that was unclear (happened to work for some values of `y`) until
#126338 was merged yesterday
and lowered it to the usual `exp(y * log(x))` outside of a few special
exponent values, such as y == 2.0` lowering to `x * x`.

It turns out that code in the wild has been relying on `math.powf(x, y)`
with negative `x` for some integral values of `y` for which a lowering
to muls was intended: iree-org/iree#19996

This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and
it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using `powf` with
negative `x`, use `math.fpowi` for that.

Signed-off-by: Benoit Jacob <[email protected]>
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
Related: llvm#124402

- change inefficient implementation of `powf(a, b)` to handle `a < 0`
case
  - thus drop `a < 0` case support

However, some special cases are being used such as:
  - `a < 0` and `b = 0, b = 0.5, b = 1 or b = 2`
  - convert those special cases into simpler ops.
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 14, 2025
`math.powf(x, y)` never really supported negative values of `x`, but
that was unclear (happened to work for some values of `y`) until
llvm/llvm-project#126338 was merged yesterday
and lowered it to the usual `exp(y * log(x))` outside of a few special
exponent values, such as y == 2.0` lowering to `x * x`.

It turns out that code in the wild has been relying on `math.powf(x, y)`
with negative `x` for some integral values of `y` for which a lowering
to muls was intended: iree-org/iree#19996

This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and
it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using `powf` with
negative `x`, use `math.fpowi` for that.

Signed-off-by: Benoit Jacob <[email protected]>
zjgarvey added a commit to llvm/torch-mlir that referenced this pull request Feb 19, 2025
Pure floating point pow operations no-longer support negative base
values (see <llvm/llvm-project#126338>), but
many models coming from ONNX use floating point representations of
integers as the exponent.

This change:

1. matches on constant rank-0 exponents and converts them to scalar
constants.
2. matches on constant floating-point scalar exponents and converts them
to ints if possible.
3. lowers `Tensor(float)^int` cases to `math.fpowi` 

Addresses some remaining test failures related to
<iree-org/iree#19996>.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Related: llvm#124402

- change inefficient implementation of `powf(a, b)` to handle `a < 0`
case
  - thus drop `a < 0` case support

However, some special cases are being used such as:
  - `a < 0` and `b = 0, b = 0.5, b = 1 or b = 2`
  - convert those special cases into simpler ops.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
`math.powf(x, y)` never really supported negative values of `x`, but
that was unclear (happened to work for some values of `y`) until
llvm#126338 was merged yesterday
and lowered it to the usual `exp(y * log(x))` outside of a few special
exponent values, such as y == 2.0` lowering to `x * x`.

It turns out that code in the wild has been relying on `math.powf(x, y)`
with negative `x` for some integral values of `y` for which a lowering
to muls was intended: iree-org/iree#19996

This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and
it is a more efficient lowering anyway.

There needs to be a wider project to stop altogether using `powf` with
negative `x`, use `math.fpowi` for that.

Signed-off-by: Benoit Jacob <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants