-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Revert "[mlir][math]Update convertPowfOp ExpandPatterns.cpp"
#126063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
)" This reverts commit 3a33775.
@llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) Changes: Reverts llvm/llvm-project#124402. It breaks an integration test in a downstream project (i.e., IREE), which produces NaNs. Talked to the author @ita9naiwa, and we agreed to reland the PR after we find the issue. Full diff: https://github.com/llvm/llvm-project/pull/126063.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 30bcdfc45837a65..3dadf9474cf4f67 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,8 +311,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
-// Restricting a >= 0
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -320,10 +319,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+ Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+ Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+ Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
+ Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
- Value logA = b.create<math::LogOp>(opType, operandA);
- Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+ Value logA = b.create<math::LogOp>(opType, opASquared);
+ Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
+ Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
+ Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+ Value negCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value oddPower =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+ Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -331,9 +341,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value finalResult =
- b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
- rewriter.replaceOp(op, finalResult);
+ Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
+ expResult);
+ res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
+ rewriter.replaceOp(op, res);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 5b443e9e8d4e78e..6055ed0504c84ca 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,15 +202,25 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) -> f64 {
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
- // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
- // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
- // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
- // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
- // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
- // CHECK: return [[SEL]]
+ // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+ // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
+ // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
+ // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
+ // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
+ // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+ // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
+ // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
+ // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+ // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+ // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+ // CHECK: return [[SEL1]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -592,15 +602,26 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]]
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
+// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : tensor<8xf32>
+
// -----
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -609,15 +630,25 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
+// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
-// CHECK: return %[[SEL]] : f32
+// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
+// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : f32
// -----
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index d1916c28878b97a..106b48a2daea2e3 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,6 +202,11 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+ // CHECK-NEXT: -27
+ %b = arith.constant -3.0 : f64
+ %b_p = arith.constant 3.0 : f64
+ call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
// CHECK-NEXT: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64
|
@llvm/pr-subscribers-mlir-math Author: Han-Chung Wang (hanhanW) Changes: Reverts llvm/llvm-project#124402. It breaks an integration test in a downstream project (i.e., IREE), which produces NaNs. Talked to the author @ita9naiwa, and we agreed to reland the PR after we find the issue. Full diff: https://github.com/llvm/llvm-project/pull/126063.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 30bcdfc45837a65..3dadf9474cf4f67 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -311,8 +311,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
-// Restricting a >= 0
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -320,10 +319,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+ Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+ Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+ Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
+ Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
- Value logA = b.create<math::LogOp>(opType, operandA);
- Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+ Value logA = b.create<math::LogOp>(opType, opASquared);
+ Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
+ Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
+ Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
+ Value negCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value oddPower =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
+ Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -331,9 +341,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value finalResult =
- b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
- rewriter.replaceOp(op, finalResult);
+ Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
+ expResult);
+ res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
+ rewriter.replaceOp(op, res);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 5b443e9e8d4e78e..6055ed0504c84ca 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,15 +202,25 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) -> f64 {
+func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
- // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
- // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
- // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
- // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
- // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
- // CHECK: return [[SEL]]
+ // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
+ // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
+ // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
+ // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
+ // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
+ // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
+ // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
+ // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
+ // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+ // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
+ // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+ // CHECK: return [[SEL1]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -592,15 +602,26 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]]
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
+// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
+// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : tensor<8xf32>
+
// -----
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -609,15 +630,25 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
+// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
+// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
-// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
-// CHECK: return %[[SEL]] : f32
+// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
+// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
+// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
+// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
+// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
+// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : f32
// -----
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index d1916c28878b97a..106b48a2daea2e3 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,6 +202,11 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+ // CHECK-NEXT: -27
+ %b = arith.constant -3.0 : f64
+ %b_p = arith.constant 3.0 : f64
+ call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
// CHECK-NEXT: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64
|
…m#126063) Reverts llvm#124402. It breaks an integration test in a downstream project (i.e., IREE), which produces NaNs. Talked to the author @ita9naiwa, and we agreed to reland the PR after we find the issue.
Reverts #124402
It breaks an integration test in a downstream project (i.e., IREE), which produces NaNs. Talked to the author @ita9naiwa, and we agreed to reland the PR after we find the issue.