Skip to content

[mlir][math] Fix math.powf expansion case for pow(x, 0) #119015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

christopherbate
Copy link
Contributor

@christopherbate christopherbate commented Dec 6, 2024

Lowering math.powf to llvm.intr.powf will result in pow(x, 0) = 1, even for x=0. When using the Math dialect expansion patterns,
pow(0, 0) will result in -nan, however, This change adds two
additional instructions to the lowering to ensure the pow(x, 0) case
lowers to to 1 regardless of the value of x.

Resolves #118945.

@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

Changes

Lowering math.powf to llvm.intr.powf will result in pow(x, 0) = 1, even for x=0. When using the Math dialect expansion patterns,
pow(0, 0) will result in -nan, however, This change adds two
additional instructions to the lowering to ensure the pow(x, 0) case
lowers to to 1 regardless of the value of x.


Full diff: https://github.com/llvm/llvm-project/pull/119015.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+8)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+23-14)
  • (modified) mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir (+9-4)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 80569d95137c3a..8bcbdb4c9a664a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -343,6 +343,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
   Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
@@ -359,8 +360,15 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
       b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
   Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
+  // First, we select between the exp value and the adjusted value for odd
+  // powers of negatives. Then, we ensure that one is produced if `b` is zero.
+  // This corresponds to `libm` behavior, even for `0^0`. Without this check,
+  // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
+  Value zeroCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
   Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
                                         expResult);
+  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
   rewriter.replaceOp(op, res);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c10a78ca4ae4ca..89413b95703322 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,11 @@ func.func @roundf_func(%a: f32) -> f32 {
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
 func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
   // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
   // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
   // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]] 
+  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
   // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
   // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
   // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
@@ -234,8 +235,10 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
   // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
   // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
   // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK: return [[SEL]]
+  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+  // CHECK: return [[SEL1]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -516,7 +519,7 @@ func.func @roundeven16(%arg: f16) -> f16 {
 
 // CHECK-LABEL:   func.func @math_fpowi_neg_odd_power
 func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<-3> : tensor<8xi64> 
+  %1 = arith.constant dense<-3> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -539,7 +542,7 @@ func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_neg_even_power
 func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<-4> : tensor<8xi64> 
+  %1 = arith.constant dense<-4> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -562,7 +565,7 @@ func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_pos_odd_power
 func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<5> : tensor<8xi64> 
+  %1 = arith.constant dense<5> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -576,7 +579,7 @@ func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_pos_even_power
 func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<4> : tensor<8xi64> 
+  %1 = arith.constant dense<4> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -617,9 +620,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK:        %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK:        %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
 // CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
 // CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
@@ -631,8 +635,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
 // CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
 // CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:      return %[[SEL]] : tensor<8xf32>
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:      return %[[SEL1]] : tensor<8xf32>
 
 // -----
 
@@ -642,9 +648,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK:        %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK:        %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK:        %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
 // CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
 // CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
@@ -656,8 +663,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
 // CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:       return %[[SEL]] : f32
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:       return %[[SEL1]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559cc6f730b..93de767b551769 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -18,7 +18,7 @@ func.func @func_exp2f(%a : f64) {
 func.func @exp2f() {
   // CHECK: 2
   %a = arith.constant 1.0 : f64
-  call @func_exp2f(%a) : (f64) -> ()  
+  call @func_exp2f(%a) : (f64) -> ()
 
   // CHECK-NEXT: 4
   %b = arith.constant 2.0 : f64
@@ -240,13 +240,18 @@ func.func @powf() {
   // CHECK-NEXT: -nan
   %k = arith.constant 1.0 : f64
   %k_p = arith.constant 0xfff0000001000000 : f64
-  call @func_powff64(%k, %k_p) : (f64, f64) -> ()  
+  call @func_powff64(%k, %k_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: -nan
   %l = arith.constant 1.0 : f32
   %l_p = arith.constant 0xffffffff : f32
-  call @func_powff32(%l, %l_p) : (f32, f32) -> ()  
-  return  
+  call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+
+  // CHECK-NEXT: 1
+  %zero = arith.constant 0.0 : f32
+  call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+
+  return
 }
 
 // -------------------------------------------------------------------------- //

@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-mlir-math

Author: Christopher Bate (christopherbate)

Changes

Lowering math.powf to llvm.intr.powf will result in pow(x, 0) = 1, even for x=0. When using the Math dialect expansion patterns,
pow(0, 0) will result in -nan, however, This change adds two
additional instructions to the lowering to ensure the pow(x, 0) case
lowers to to 1 regardless of the value of x.


Full diff: https://github.com/llvm/llvm-project/pull/119015.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+8)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+23-14)
  • (modified) mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir (+9-4)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 80569d95137c3a..8bcbdb4c9a664a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -343,6 +343,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
   Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
   Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
@@ -359,8 +360,15 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
       b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
   Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
 
+  // First, we select between the exp value and the adjusted value for odd
+  // powers of negatives. Then, we ensure that one is produced if `b` is zero.
+  // This corresponds to `libm` behavior, even for `0^0`. Without this check,
+  // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
+  Value zeroCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
   Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
                                         expResult);
+  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
   rewriter.replaceOp(op, res);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c10a78ca4ae4ca..89413b95703322 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -222,10 +222,11 @@ func.func @roundf_func(%a: f32) -> f32 {
 // CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
 func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
   // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
   // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
   // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
-  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]] 
+  // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
   // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
   // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
   // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
@@ -234,8 +235,10 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
   // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
   // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
   // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
+  // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
   // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
-  // CHECK: return [[SEL]]
+  // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
+  // CHECK: return [[SEL1]]
   %ret = math.powf %a, %b : f64
   return %ret : f64
 }
@@ -516,7 +519,7 @@ func.func @roundeven16(%arg: f16) -> f16 {
 
 // CHECK-LABEL:   func.func @math_fpowi_neg_odd_power
 func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<-3> : tensor<8xi64> 
+  %1 = arith.constant dense<-3> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -539,7 +542,7 @@ func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_neg_even_power
 func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<-4> : tensor<8xi64> 
+  %1 = arith.constant dense<-4> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -562,7 +565,7 @@ func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_pos_odd_power
 func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<5> : tensor<8xi64> 
+  %1 = arith.constant dense<5> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -576,7 +579,7 @@ func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
 
 // CHECK-LABEL:   func.func @math_fpowi_pos_even_power
 func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
-  %1 = arith.constant dense<4> : tensor<8xi64> 
+  %1 = arith.constant dense<4> : tensor<8xi64>
   %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
   return %2 : tensor<8xf32>
 }
@@ -617,9 +620,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
   return %2 : tensor<8xf32>
 }
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK:        %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK:        %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK:        %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
 // CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
 // CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
@@ -631,8 +635,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
 // CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
 // CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK:      return %[[SEL]] : tensor<8xf32>
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:      return %[[SEL1]] : tensor<8xf32>
 
 // -----
 
@@ -642,9 +648,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
   return %2 : f32
 }
 // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK:        %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK:        %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK:        %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK-DAG:    %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG:    %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK:        %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
 // CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
 // CHECK:        %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
@@ -656,8 +663,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
 // CHECK:        %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
 // CHECK:        %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
 // CHECK:        %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
+// CHECK:        %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
 // CHECK:        %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK:       return %[[SEL]] : f32
+// CHECK:        %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
+// CHECK:       return %[[SEL1]] : f32
 
 // -----
 
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559cc6f730b..93de767b551769 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -18,7 +18,7 @@ func.func @func_exp2f(%a : f64) {
 func.func @exp2f() {
   // CHECK: 2
   %a = arith.constant 1.0 : f64
-  call @func_exp2f(%a) : (f64) -> ()  
+  call @func_exp2f(%a) : (f64) -> ()
 
   // CHECK-NEXT: 4
   %b = arith.constant 2.0 : f64
@@ -240,13 +240,18 @@ func.func @powf() {
   // CHECK-NEXT: -nan
   %k = arith.constant 1.0 : f64
   %k_p = arith.constant 0xfff0000001000000 : f64
-  call @func_powff64(%k, %k_p) : (f64, f64) -> ()  
+  call @func_powff64(%k, %k_p) : (f64, f64) -> ()
 
   // CHECK-NEXT: -nan
   %l = arith.constant 1.0 : f32
   %l_p = arith.constant 0xffffffff : f32
-  call @func_powff32(%l, %l_p) : (f32, f32) -> ()  
-  return  
+  call @func_powff32(%l, %l_p) : (f32, f32) -> ()
+
+  // CHECK-NEXT: 1
+  %zero = arith.constant 0.0 : f32
+  call @func_powff32(%zero, %zero) : (f32, f32) -> ()
+
+  return
 }
 
 // -------------------------------------------------------------------------- //

Copy link
Member

@pashu123 pashu123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@christopherbate christopherbate force-pushed the mlir-math-powf-lowering-fix branch from 014b036 to 02ab816 Compare December 9, 2024 18:52
Lowering `math.powf` to `llvm.intr.powf` will result in `pow(x, 0) =
1`, even for `x=0`. When using the Math dialect expansion patterns,
`pow(0, 0)` will result in `-nan`, however, This change adds two
additional instructions to the lowering to ensure the `pow(x, 0)` case
lowers to to `1` regardless of the value of `x`.
@christopherbate christopherbate force-pushed the mlir-math-powf-lowering-fix branch from 02ab816 to 12e0731 Compare December 9, 2024 21:52
@christopherbate christopherbate merged commit a92e3df into llvm:main Dec 9, 2024
4 of 6 checks passed
@christopherbate christopherbate deleted the mlir-math-powf-lowering-fix branch December 9, 2024 21:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] Inconsistent output when executing MLIR program with and without -test-expand-math
3 participants