Skip to content

Commit 3a33775

Browse files
authored
[mlir][math]Update convertPowfOp ExpandPatterns.cpp (llvm#124402)
The current implementation of `convertPowfOp` requires a calculation of `a * a` but, max\<fp16\> ~= 65,504, and if `a` is about 16, it will overflow so get INF in fp8 or fp16 easily. Remove support when `a < 0`. Overhead of handling negative value of `a` is large and easy to overflow; - related issue in iree: iree-org/iree#15936
1 parent 3bd3e06 commit 3a33775

File tree

3 files changed

+27
-74
lines changed

3 files changed

+27
-74
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -311,40 +311,29 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
311311
return success();
312312
}
313313

314-
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
314+
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
315+
// Restricting a >= 0
315316
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
316317
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
317318
Value operandA = op.getOperand(0);
318319
Value operandB = op.getOperand(1);
319320
Type opType = operandA.getType();
320321
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
321322
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
322-
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
323-
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
324-
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
325-
Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
326323

327-
Value logA = b.create<math::LogOp>(opType, opASquared);
328-
Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
324+
Value logA = b.create<math::LogOp>(opType, operandA);
325+
Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
329326
Value expResult = b.create<math::ExpOp>(opType, mult);
330-
Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
331-
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
332-
Value negCheck =
333-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
334-
Value oddPower =
335-
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
336-
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
337327

338328
// First, we select between the exp value and the adjusted value for odd
339329
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
340330
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
341331
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
342332
Value zeroCheck =
343333
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
344-
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
345-
expResult);
346-
res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
347-
rewriter.replaceOp(op, res);
334+
Value finalResult =
335+
b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
336+
rewriter.replaceOp(op, finalResult);
348337
return success();
349338
}
350339

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 20 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 {
202202

203203
// CHECK-LABEL: func @powf_func
204204
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
205-
func.func @powf_func(%a: f64, %b: f64) ->f64 {
205+
func.func @powf_func(%a: f64, %b: f64) -> f64 {
206206
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
207207
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
208-
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
209-
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
210-
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
211-
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
212-
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
213-
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
214-
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
215-
// CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
216-
// CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
217-
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
218-
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
219-
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
220-
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
221-
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
222-
// CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
223-
// CHECK: return [[SEL1]]
208+
// CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
209+
// CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
210+
// CHECK: [[EXP:%.+]] = math.exp [[MULB]]
211+
// CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
212+
// CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
213+
// CHECK: return [[SEL]]
224214
%ret = math.powf %a, %b : f64
225215
return %ret : f64
226216
}
@@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
602592
return %2 : tensor<8xf32>
603593
}
604594
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
605-
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
606-
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
607-
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
608595
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
609-
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
610-
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
611-
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
612-
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
613-
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
614-
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
615-
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
616-
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
617-
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
618-
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
619-
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
620-
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
621-
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
622-
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
623-
// CHECK: return %[[SEL1]] : tensor<8xf32>
624-
596+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
597+
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
598+
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
599+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
600+
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
601+
// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
602+
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
603+
// CHECK: return %[[SEL]]
625604
// -----
626605

627606
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
630609
return %2 : f32
631610
}
632611
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
633-
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
634-
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
635612
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
636613
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
637614
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
638-
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
639-
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
640-
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
641-
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
615+
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
616+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
642617
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
643-
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
644-
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
645-
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
646-
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
647-
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
648-
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
649-
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
650-
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
651-
// CHECK: return %[[SEL1]] : f32
618+
// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
619+
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
620+
// CHECK: return %[[SEL]] : f32
652621

653622
// -----
654623

mlir/test/mlir-runner/test-expand-math-approx.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,6 @@ func.func @powf() {
202202
%a_p = arith.constant 2.0 : f64
203203
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
204204

205-
// CHECK-NEXT: -27
206-
%b = arith.constant -3.0 : f64
207-
%b_p = arith.constant 3.0 : f64
208-
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
209-
210205
// CHECK-NEXT: 2.343
211206
%c = arith.constant 2.343 : f64
212207
%c_p = arith.constant 1.000 : f64

0 commit comments

Comments
 (0)