Skip to content

Revert "[mlir][math]Update convertPowfOp ExpandPatterns.cpp" #126063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2025

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Feb 6, 2025

Reverts #124402

It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue.

@hanhanW hanhanW merged commit c9d0a46 into main Feb 6, 2025
7 of 9 checks passed
@hanhanW hanhanW deleted the revert-124402-main branch February 6, 2025 13:19
@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Reverts llvm/llvm-project#124402

It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue.


Full diff: https://github.com/llvm/llvm-project/pull/126063.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+18-7)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+51-20)
  • (modified) mlir/test/mlir-runner/test-expand-math-approx.mlir (+5)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 30bcdfc45837a65..3dadf9474cf4f67 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,8 +311,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
-// Restricting a >= 0
+// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -320,10 +319,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
+  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
 
-  Value logA = b.create<math::LogOp>(opType, operandA);
-  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+  Value logA = b.create<math::LogOp>(opType, opASquared);
+  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
+  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
+  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+  Value negCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value oddPower =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -331,9 +341,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
   Value zeroCheck =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value finalResult =
-      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
-  rewriter.replaceOp(op, finalResult);
+  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
+                                        expResult);
+  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 5b443e9e8d4e78e..6055ed0504c84ca 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,15 +202,25 @@ func.func @roundf_func(%a: f32) -> f32 {
 
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) -> f64 {
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
   // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
-  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
-  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
-  // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
-  // CHECK: return [[SEL]]
+  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
+  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
+  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
+  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
+  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
+  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+  // CHECK: return [[SEL1]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -592,15 +602,26 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]]
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
+// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
+// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:      return %[[SEL1]] : tensor<8xf32>
+
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -609,15 +630,25 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK:        %[[LOGA:.*]] = math.log %[[ARG0]] : f32
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
+// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
 // CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
-// CHECK:        %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
-// CHECK:       return %[[SEL]] : f32
+// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
+// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:       return %[[SEL1]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index d1916c28878b97a..106b48a2daea2e3 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,6 +202,11 @@ func.func @powf() {
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
+  // CHECK-NEXT: -27
+  %b   = arith.constant -3.0 : f64
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
   // CHECK-NEXT: 2.343
   %c   = arith.constant 2.343 : f64
   %c_p = arith.constant 1.000 : f64

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir-math

Author: Han-Chung Wang (hanhanW)

Changes

Reverts llvm/llvm-project#124402

It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue.


Full diff: https://github.com/llvm/llvm-project/pull/126063.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+18-7)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+51-20)
  • (modified) mlir/test/mlir-runner/test-expand-math-approx.mlir (+5)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 30bcdfc45837a65..3dadf9474cf4f67 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,8 +311,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
   return success();
 }
 
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
-// Restricting a >= 0
+// Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -320,10 +319,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
+  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
 
-  Value logA = b.create<math::LogOp>(opType, operandA);
-  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+  Value logA = b.create<math::LogOp>(opType, opASquared);
+  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
   Value expResult = b.create<math::ExpOp>(opType, mult);
+  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
+  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+  Value negCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value oddPower =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -331,9 +341,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
   Value zeroCheck =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value finalResult =
-      b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
-  rewriter.replaceOp(op, finalResult);
+  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
+                                        expResult);
+  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 5b443e9e8d4e78e..6055ed0504c84ca 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,15 +202,25 @@ func.func @roundf_func(%a: f32) -> f32 {
 
 // CHECK-LABEL:   func @powf_func
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) -> f64 {
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
   // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
-  // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
-  // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
-  // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
-  // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
-  // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
-  // CHECK: return [[SEL]]
+  // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+  // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
+  // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
+  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
+  // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
+  // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+  // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
+  // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+  // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
+  // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+  // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+  // CHECK: return [[SEL1]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -592,15 +602,26 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]]
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
+// CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
+// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:      return %[[SEL1]] : tensor<8xf32>
+
 // -----
 
 // CHECK-LABEL:   func.func @math_fpowi_to_powf_scalar
@@ -609,15 +630,25 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK:        %[[LOGA:.*]] = math.log %[[ARG0]] : f32
-// CHECK:        %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+// CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
+// CHECK:        %[[LG:.*]] = math.log %[[SQ]] : f32
+// CHECK:        %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
 // CHECK:        %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK:        %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
-// CHECK:        %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
-// CHECK:       return %[[SEL]] : f32
+// CHECK:        %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK:        %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
+// CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:       return %[[SEL1]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index d1916c28878b97a..106b48a2daea2e3 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,6 +202,11 @@ func.func @powf() {
   %a_p = arith.constant 2.0 : f64
   call @func_powff64(%a, %a_p) : (f64, f64) -> ()
 
+  // CHECK-NEXT: -27
+  %b   = arith.constant -3.0 : f64
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
   // CHECK-NEXT: 2.343
   %c   = arith.constant 2.343 : f64
   %c_p = arith.constant 1.000 : f64

Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…m#126063)

Reverts llvm#124402

It breaks an integration test in downstream project (i.e., IREE), which
produces NANs. Talked to the author @ita9naiwa, and we agree to reland
the PR after we find the issue.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants