Skip to content

Commit a92e3df

Browse files
[mlir][math] Fix math.powf expansion case for pow(x, 0) (#119015)
Lowering `math.powf` to `llvm.intr.powf` will result in `pow(x, 0) = 1`, even for `x=0`. When using the Math dialect expansion patterns, `pow(0, 0)` will result in `-nan`, however, This change adds two additional instructions to the lowering to ensure the `pow(x, 0)` case lowers to to `1` regardless of the value of `x`. Resolves #118945.
1 parent 8471541 commit a92e3df

File tree

3 files changed

+40
-18
lines changed

3 files changed

+40
-18
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
343343
Value operandB = op.getOperand(1);
344344
Type opType = operandA.getType();
345345
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
346+
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
346347
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
347348
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
348349
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
@@ -359,8 +360,15 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
359360
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
360361
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
361362

363+
// First, we select between the exp value and the adjusted value for odd
364+
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
365+
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
366+
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
367+
Value zeroCheck =
368+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
362369
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
363370
expResult);
371+
res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
364372
rewriter.replaceOp(op, res);
365373
return success();
366374
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,11 @@ func.func @roundf_func(%a: f32) -> f32 {
222222
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
223223
func.func @powf_func(%a: f64, %b: f64) ->f64 {
224224
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
225+
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
225226
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
226227
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
227228
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
228-
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
229+
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
229230
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
230231
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
231232
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
@@ -234,8 +235,10 @@ func.func @powf_func(%a: f64, %b: f64) ->f64 {
234235
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
235236
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
236237
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
238+
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
237239
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
238-
// CHECK: return [[SEL]]
240+
// CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
241+
// CHECK: return [[SEL1]]
239242
%ret = math.powf %a, %b : f64
240243
return %ret : f64
241244
}
@@ -516,7 +519,7 @@ func.func @roundeven16(%arg: f16) -> f16 {
516519

517520
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
518521
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
519-
%1 = arith.constant dense<-3> : tensor<8xi64>
522+
%1 = arith.constant dense<-3> : tensor<8xi64>
520523
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
521524
return %2 : tensor<8xf32>
522525
}
@@ -539,7 +542,7 @@ func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
539542

540543
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
541544
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
542-
%1 = arith.constant dense<-4> : tensor<8xi64>
545+
%1 = arith.constant dense<-4> : tensor<8xi64>
543546
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
544547
return %2 : tensor<8xf32>
545548
}
@@ -562,7 +565,7 @@ func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
562565

563566
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
564567
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
565-
%1 = arith.constant dense<5> : tensor<8xi64>
568+
%1 = arith.constant dense<5> : tensor<8xi64>
566569
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
567570
return %2 : tensor<8xf32>
568571
}
@@ -576,7 +579,7 @@ func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
576579

577580
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
578581
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
579-
%1 = arith.constant dense<4> : tensor<8xi64>
582+
%1 = arith.constant dense<4> : tensor<8xi64>
580583
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
581584
return %2 : tensor<8xf32>
582585
}
@@ -617,9 +620,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
617620
return %2 : tensor<8xf32>
618621
}
619622
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
620-
// CHECK: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
621-
// CHECK: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
622-
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
623+
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
624+
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
625+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
626+
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
623627
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
624628
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
625629
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
@@ -631,8 +635,10 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
631635
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
632636
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
633637
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
638+
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
634639
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
635-
// CHECK: return %[[SEL]] : tensor<8xf32>
640+
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
641+
// CHECK: return %[[SEL1]] : tensor<8xf32>
636642

637643
// -----
638644

@@ -642,9 +648,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
642648
return %2 : f32
643649
}
644650
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
645-
// CHECK: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
646-
// CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
647-
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
651+
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
652+
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
653+
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
654+
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
648655
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
649656
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
650657
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
@@ -656,8 +663,10 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
656663
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
657664
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
658665
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
666+
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
659667
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
660-
// CHECK: return %[[SEL]] : f32
668+
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
669+
// CHECK: return %[[SEL1]] : f32
661670

662671
// -----
663672

mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func.func @func_exp2f(%a : f64) {
1818
func.func @exp2f() {
1919
// CHECK: 2
2020
%a = arith.constant 1.0 : f64
21-
call @func_exp2f(%a) : (f64) -> ()
21+
call @func_exp2f(%a) : (f64) -> ()
2222

2323
// CHECK-NEXT: 4
2424
%b = arith.constant 2.0 : f64
@@ -240,13 +240,18 @@ func.func @powf() {
240240
// CHECK-NEXT: -nan
241241
%k = arith.constant 1.0 : f64
242242
%k_p = arith.constant 0xfff0000001000000 : f64
243-
call @func_powff64(%k, %k_p) : (f64, f64) -> ()
243+
call @func_powff64(%k, %k_p) : (f64, f64) -> ()
244244

245245
// CHECK-NEXT: -nan
246246
%l = arith.constant 1.0 : f32
247247
%l_p = arith.constant 0xffffffff : f32
248-
call @func_powff32(%l, %l_p) : (f32, f32) -> ()
249-
return
248+
call @func_powff32(%l, %l_p) : (f32, f32) -> ()
249+
250+
// CHECK-NEXT: 1
251+
%zero = arith.constant 0.0 : f32
252+
call @func_powff32(%zero, %zero) : (f32, f32) -> ()
253+
254+
return
250255
}
251256

252257
// -------------------------------------------------------------------------- //

0 commit comments

Comments
 (0)