Skip to content

Commit c477e70

Browse files
committed
remove case a<0
1 parent 19ba511 commit c477e70

File tree

3 files changed

+23
-66
lines changed

3 files changed

+23
-66
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -311,39 +311,28 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
311311
return success();
312312
}
313313

314-
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(|a|))
315-
// * sign(a)^b
314+
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
315+
// Restricting a >= 0
316316
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
317317
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
318318
Value operandA = op.getOperand(0);
319319
Value operandB = op.getOperand(1);
320320
Type opType = operandA.getType();
321321
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
322322
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
323-
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
324-
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
325-
326-
Value absA = b.create<math::AbsFOp>(opType, operandA);
327-
Value isNegative =
328-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
329-
Value signA =
330-
b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
331-
Value logA = b.create<math::LogOp>(opType, absA);
323+
324+
Value logA = b.create<math::LogOp>(opType, operandA);
332325
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
333326
Value expResult = b.create<math::ExpOp>(opType, mult);
334-
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
335-
Value isOdd =
336-
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
337-
Value signedExpResult = b.create<arith::SelectOp>(
338-
op->getLoc(), isOdd, b.create<arith::MulFOp>(opType, expResult, signA),
339-
expResult);
340327

328+
// First, we select between the exp value and the adjusted value for odd
329+
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
341330
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
342331
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
343332
Value zeroCheck =
344333
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
345334
Value finalResult =
346-
b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, signedExpResult);
335+
b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
347336
rewriter.replaceOp(op, finalResult);
348337
return success();
349338
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -205,21 +205,12 @@ func.func @roundf_func(%a: f32) -> f32 {
205205
func.func @powf_func(%a: f64, %b: f64) -> f64 {
206206
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
207207
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
208-
// CHECK-DAG: [[CSTNEG1:%.+]] = arith.constant -1.000000e+00
209-
// CHECK-DAG: [[CSTTWO:%.+]] = arith.constant 2.000000e+00
210-
// CHECK: [[ABSA:%.+]] = math.absf [[ARG0]]
211-
// CHECK: [[ISNEG:%.+]] = arith.cmpf olt, [[ARG0]], [[CST0]]
212-
// CHECK: [[SIGNA:%.+]] = arith.select [[ISNEG]], [[CSTNEG1]], [[CST1]]
213-
// CHECK: [[LOGA:%.+]] = math.log [[ABSA]]
208+
// CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
214209
// CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
215210
// CHECK: [[EXP:%.+]] = math.exp [[MULB]]
216-
// CHECK: [[REM:%.+]] = arith.remf [[ARG1]], [[CSTTWO]]
217-
// CHECK: [[CMPF:%.+]] = arith.cmpf one, [[REM]], [[CST0]]
218-
// CHECK: [[ABMUL:%.+]] = arith.mulf [[EXP]], [[SIGNA]]
219-
// CHECK: [[SEL0:%.+]] = arith.select [[CMPF]], [[ABMUL]], [[EXP]]
220-
// CHECK: [[CMPF2:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
221-
// CHECK: [[SEL1:%.+]] = arith.select [[CMPF2]], [[CST1]], [[SEL0]]
222-
// CHECK: return [[SEL1]]
211+
// CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
212+
// CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
213+
// CHECK: return [[SEL]]
223214
%ret = math.powf %a, %b : f64
224215
return %ret : f64
225216
}
@@ -601,24 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
601592
return %2 : tensor<8xf32>
602593
}
603594
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
604-
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
605-
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
606595
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
607596
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
608597
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
609-
// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : tensor<8xf32>
610-
// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
611-
// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : tensor<8xi1>, tensor<8xf32>
612-
// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : tensor<8xf32>
613-
// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
614-
// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : tensor<8xf32>
615-
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
616-
// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
617-
// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : tensor<8xf32>
618-
// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : tensor<8xi1>, tensor<8xf32>
619-
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
620-
// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : tensor<8xi1>, tensor<8xf32>
621-
// CHECK: return %[[SEL1]]
598+
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
599+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
600+
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
601+
// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
602+
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
603+
// CHECK: return %[[SEL]]
622604
// -----
623605

624606
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -627,24 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
627609
return %2 : f32
628610
}
629611
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
630-
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
631-
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
632612
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
633613
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
634614
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
635-
// CHECK: %[[ABSA:.*]] = math.absf %[[ARG0]] : f32
636-
// CHECK: %[[ISNEG:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
637-
// CHECK: %[[SIGNA:.*]] = arith.select %[[ISNEG]], %[[CSTNEG1]], %[[CST1]] : f32
638-
// CHECK: %[[LOGA:.*]] = math.log %[[ABSA]] : f32
639-
// CHECK: %[[MULA:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
640-
// CHECK: %[[EXPA:.*]] = math.exp %[[MULA]] : f32
641-
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
642-
// CHECK: %[[CMPF:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
643-
// CHECK: %[[ABMUL:.*]] = arith.mulf %[[EXPA]], %[[SIGNA]] : f32
644-
// CHECK: %[[SEL0:.*]] = arith.select %[[CMPF]], %[[ABMUL]], %[[EXPA]] : f32
645-
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
646-
// CHECK: %[[SEL1:.*]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL0]] : f32
647-
// CHECK: return %[[SEL1]] : f32
615+
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
616+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
617+
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
618+
// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
619+
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
620+
// CHECK: return %[[SEL]] : f32
648621

649622
// -----
650623

mlir/test/mlir-runner/test-expand-math-approx.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,6 @@ func.func @powf() {
202202
%a_p = arith.constant 2.0 : f64
203203
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
204204

205-
// CHECK-NEXT: -27
206-
%b = arith.constant -3.0 : f64
207-
%b_p = arith.constant 3.0 : f64
208-
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
209-
210205
// CHECK-NEXT: 2.343
211206
%c = arith.constant 2.343 : f64
212207
%c_p = arith.constant 1.000 : f64

0 commit comments

Comments
 (0)