-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][math] Fix math.powf
expansion case for pow(x, 0)
#119015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][math] Fix math.powf
expansion case for pow(x, 0)
#119015
Conversation
@llvm/pr-subscribers-mlir Author: Christopher Bate (christopherbate) ChangesLowering Full diff: https://github.com/llvm/llvm-project/pull/119015.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 80569d95137c3a..8bcbdb4c9a664a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -343,6 +343,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+ Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
@@ -359,8 +360,15 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+ // First, we select between the exp value and the adjusted value for odd
+ // powers of negatives. Then, we ensure that one is produced if `b` is zero.
+ // This corresponds to `libm` behavior, even for `0^0`. Without this check,
+ // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
+ Value zeroCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
expResult);
+ res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
rewriter.replaceOp(op, res);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c10a78ca4ae4ca..89413b95703322 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,11 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+ // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
- // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
@@ -234,8 +235,10 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
- // CHECK: return [[SEL]]
+ // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+ // CHECK: return [[SEL1]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -516,7 +519,7 @@ func.func @roundeven16(%arg: f16) -> f16 {
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<-3> : tensor<8xi64>
+ %1 = arith.constant dense<-3> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -539,7 +542,7 @@ func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<-4> : tensor<8xi64>
+ %1 = arith.constant dense<-4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -562,7 +565,7 @@ func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<5> : tensor<8xi64>
+ %1 = arith.constant dense<5> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -576,7 +579,7 @@ func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<4> : tensor<8xi64>
+ %1 = arith.constant dense<4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -617,9 +620,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
@@ -631,8 +635,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]] : tensor<8xf32>
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : tensor<8xf32>
// -----
@@ -642,9 +648,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
@@ -656,8 +663,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK: return %[[SEL]] : f32
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : f32
// -----
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559cc6f730b..93de767b551769 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -18,7 +18,7 @@ func.func @func_exp2f(%a : f64) {
func.func @exp2f() {
// CHECK: 2
%a = arith.constant 1.0 : f64
- call @func_exp2f(%a) : (f64) -> ()
+ call @func_exp2f(%a) : (f64) -> ()
// CHECK-NEXT: 4
%b = arith.constant 2.0 : f64
@@ -240,13 +240,18 @@ func.func @powf() {
// CHECK-NEXT: -nan
%k = arith.constant 1.0 : f64
%k_p = arith.constant 0xfff0000001000000 : f64
- call @func_powff64(%k, %k_p) : (f64, f64) -> ()
+ call @func_powff64(%k, %k_p) : (f64, f64) -> ()
// CHECK-NEXT: -nan
%l = arith.constant 1.0 : f32
%l_p = arith.constant 0xffffffff : f32
- call @func_powff32(%l, %l_p) : (f32, f32) -> ()
- return
+ call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+
+ // CHECK-NEXT: 1
+ %zero = arith.constant 0.0 : f32
+ call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+
+ return
}
// -------------------------------------------------------------------------- //
|
@llvm/pr-subscribers-mlir-math Author: Christopher Bate (christopherbate) ChangesLowering Full diff: https://github.com/llvm/llvm-project/pull/119015.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 80569d95137c3a..8bcbdb4c9a664a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -343,6 +343,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+ Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
@@ -359,8 +360,15 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+ // First, we select between the exp value and the adjusted value for odd
+ // powers of negatives. Then, we ensure that one is produced if `b` is zero.
+ // This corresponds to `libm` behavior, even for `0^0`. Without this check,
+ // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
+ Value zeroCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
expResult);
+ res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
rewriter.replaceOp(op, res);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c10a78ca4ae4ca..89413b95703322 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,11 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+ // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
- // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
+ // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
@@ -234,8 +235,10 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+ // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
- // CHECK: return [[SEL]]
+ // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+ // CHECK: return [[SEL1]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -516,7 +519,7 @@ func.func @roundeven16(%arg: f16) -> f16 {
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<-3> : tensor<8xi64>
+ %1 = arith.constant dense<-3> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -539,7 +542,7 @@ func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<-4> : tensor<8xi64>
+ %1 = arith.constant dense<-4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -562,7 +565,7 @@ func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<5> : tensor<8xi64>
+ %1 = arith.constant dense<5> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -576,7 +579,7 @@ func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
- %1 = arith.constant dense<4> : tensor<8xi64>
+ %1 = arith.constant dense<4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
@@ -617,9 +620,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
@@ -631,8 +635,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: return %[[SEL]] : tensor<8xf32>
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : tensor<8xf32>
// -----
@@ -642,9 +648,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
@@ -656,8 +663,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK: return %[[SEL]] : f32
+// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK: return %[[SEL1]] : f32
// -----
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559cc6f730b..93de767b551769 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -18,7 +18,7 @@ func.func @func_exp2f(%a : f64) {
func.func @exp2f() {
// CHECK: 2
%a = arith.constant 1.0 : f64
- call @func_exp2f(%a) : (f64) -> ()
+ call @func_exp2f(%a) : (f64) -> ()
// CHECK-NEXT: 4
%b = arith.constant 2.0 : f64
@@ -240,13 +240,18 @@ func.func @powf() {
// CHECK-NEXT: -nan
%k = arith.constant 1.0 : f64
%k_p = arith.constant 0xfff0000001000000 : f64
- call @func_powff64(%k, %k_p) : (f64, f64) -> ()
+ call @func_powff64(%k, %k_p) : (f64, f64) -> ()
// CHECK-NEXT: -nan
%l = arith.constant 1.0 : f32
%l_p = arith.constant 0xffffffff : f32
- call @func_powff32(%l, %l_p) : (f32, f32) -> ()
- return
+ call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+
+ // CHECK-NEXT: 1
+ %zero = arith.constant 0.0 : f32
+ call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+
+ return
}
// -------------------------------------------------------------------------- //
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
014b036
to
02ab816
Compare
Lowering `math.powf` to `llvm.intr.powf` will result in `pow(x, 0) = 1`, even for `x=0`. When using the Math dialect expansion patterns, `pow(0, 0)` will result in `-nan`, however, This change adds two additional instructions to the lowering to ensure the `pow(x, 0)` case lowers to to `1` regardless of the value of `x`.
02ab816
to
12e0731
Compare
Lowering
math.powf
tollvm.intr.powf
will result inpow(x, 0) = 1
, even forx=0
. When using the Math dialect expansion patterns,pow(0, 0)
will result in-nan
, however, This change adds twoadditional instructions to the lowering to ensure the
pow(x, 0)
caselowers to to
1
regardless of the value ofx
.Resolves #118945.