Skip to content

Commit 8d72718

Browse files
ita9naiwa authored and joaosaffran committed
[mlir][math] powf(a, b) drop support when a < 0 (llvm#126338)
Related: llvm#124402. The previous implementation of `powf(a, b)` was inefficient because of the extra work needed to handle the `a < 0` case, so this change simplifies the lowering and drops support for `a < 0` in the general case. However, some special cases with `a < 0` are still in use, such as `b = 0`, `b = 0.5`, `b = 1`, or `b = 2`; those special cases are converted into simpler ops instead.
1 parent 751a008 commit 8d72718

File tree

3 files changed

+188
-118
lines changed

3 files changed

+188
-118
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/ImplicitLocOpBuilder.h"
2020
#include "mlir/IR/TypeUtilities.h"
2121
#include "mlir/Transforms/DialectConversion.h"
22+
#include "llvm/ADT/APFloat.h"
2223

2324
using namespace mlir;
2425

@@ -311,40 +312,71 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
311312
return success();
312313
}
313314

314-
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
315+
// Converts Powf(float a, float b) (meaning a^b) to exp(b * ln(a)).
316+
// Some special cases where b is constant are handled separately:
317+
// when b == 0, or |b| == 0.5, 1.0, or 2.0.
315318
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
316319
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
317320
Value operandA = op.getOperand(0);
318321
Value operandB = op.getOperand(1);
319-
Type opType = operandA.getType();
320-
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
321-
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
322-
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
323-
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
324-
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
325-
Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
326-
327-
Value logA = b.create<math::LogOp>(opType, opASquared);
328-
Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
329-
Value expResult = b.create<math::ExpOp>(opType, mult);
330-
Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
331-
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
332-
Value negCheck =
333-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
334-
Value oddPower =
335-
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
336-
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
337-
338-
// First, we select between the exp value and the adjusted value for odd
339-
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
340-
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
341-
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
342-
Value zeroCheck =
343-
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
344-
Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
345-
expResult);
346-
res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
347-
rewriter.replaceOp(op, res);
322+
auto typeA = operandA.getType();
323+
auto typeB = operandB.getType();
324+
325+
auto &sem =
326+
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
327+
APFloat valueB(sem);
328+
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
329+
if (valueB.isZero()) {
330+
// a^0 -> 1
331+
Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
332+
rewriter.replaceOp(op, one);
333+
return success();
334+
}
335+
if (valueB.isExactlyValue(1.0)) {
336+
// a^1 -> a
337+
rewriter.replaceOp(op, operandA);
338+
return success();
339+
}
340+
if (valueB.isExactlyValue(-1.0)) {
341+
// a^(-1) -> 1 / a
342+
Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
343+
Value div = b.create<arith::DivFOp>(one, operandA);
344+
rewriter.replaceOp(op, div);
345+
return success();
346+
}
347+
if (valueB.isExactlyValue(0.5)) {
348+
// a^(1/2) -> sqrt(a)
349+
Value sqrt = b.create<math::SqrtOp>(operandA);
350+
rewriter.replaceOp(op, sqrt);
351+
return success();
352+
}
353+
if (valueB.isExactlyValue(-0.5)) {
354+
// a^(-1/2) -> rsqrt(a), i.e. 1 / sqrt(a)
355+
Value rsqrt = b.create<math::RsqrtOp>(operandA);
356+
rewriter.replaceOp(op, rsqrt);
357+
return success();
358+
}
359+
if (valueB.isExactlyValue(2.0)) {
360+
// a^2 -> a * a
361+
Value mul = b.create<arith::MulFOp>(operandA, operandA);
362+
rewriter.replaceOp(op, mul);
363+
return success();
364+
}
365+
if (valueB.isExactlyValue(-2.0)) {
366+
// a^(-2) -> 1 / (a * a)
367+
Value mul = b.create<arith::MulFOp>(operandA, operandA);
368+
Value one =
369+
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
370+
Value div = b.create<arith::DivFOp>(one, mul);
371+
rewriter.replaceOp(op, div);
372+
return success();
373+
}
374+
}
375+
376+
Value logA = b.create<math::LogOp>(operandA);
377+
Value mult = b.create<arith::MulFOp>(operandB, logA);
378+
Value expResult = b.create<math::ExpOp>(mult);
379+
rewriter.replaceOp(op, expResult);
348380
return success();
349381
}
350382

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 88 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -201,26 +201,86 @@ func.func @roundf_func(%a: f32) -> f32 {
201201
// -----
202202

203203
// CHECK-LABEL: func @powf_func
204-
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
205-
func.func @powf_func(%a: f64, %b: f64) ->f64 {
206-
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
207-
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
208-
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
209-
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
210-
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
211-
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
212-
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
213-
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
214-
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
215-
// CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
216-
// CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
217-
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
218-
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
219-
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
220-
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
221-
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
222-
// CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
223-
// CHECK: return [[SEL1]]
204+
// CHECK-SAME: (%[[ARG0:.+]]: f64, %[[ARG1:.+]]: f64) -> f64
205+
func.func @powf_func(%a: f64, %b: f64) -> f64 {
206+
// CHECK: %[[LOGA:.+]] = math.log %[[ARG0]] : f64
207+
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG1]], %[[LOGA]] : f64
208+
// CHECK: %[[EXP:.+]] = math.exp %[[MUL]] : f64
209+
// CHECK: return %[[EXP]] : f64
210+
%ret = math.powf %a, %b : f64
211+
return %ret : f64
212+
}
213+
214+
// CHECK-LABEL: func @powf_func_zero
215+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
216+
func.func @powf_func_zero(%a: f64) -> f64{
217+
// CHECK: %[[ONE:.+]] = arith.constant 1.000000e+00 : f64
218+
// CHECK: return %[[ONE]] : f64
219+
%b = arith.constant 0.0 : f64
220+
%ret = math.powf %a, %b : f64
221+
return %ret : f64
222+
}
223+
224+
// CHECK-LABEL: func @powf_func_one
225+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
226+
func.func @powf_func_one(%a: f64) -> f64{
227+
// CHECK: return %[[ARG0]] : f64
228+
%b = arith.constant 1.0 : f64
229+
%ret = math.powf %a, %b : f64
230+
return %ret : f64
231+
}
232+
233+
// CHECK-LABEL: func @powf_func_negone
234+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
235+
func.func @powf_func_negone(%a: f64) -> f64{
236+
// CHECK: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64
237+
// CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[ARG0]] : f64
238+
// CHECK: return %[[DIV]] : f64
239+
%b = arith.constant -1.0 : f64
240+
%ret = math.powf %a, %b : f64
241+
return %ret : f64
242+
}
243+
244+
// CHECK-LABEL: func @powf_func_half
245+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
246+
func.func @powf_func_half(%a: f64) -> f64{
247+
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ARG0]] : f64
248+
// CHECK: return %[[SQRT]] : f64
249+
%b = arith.constant 0.5 : f64
250+
%ret = math.powf %a, %b : f64
251+
return %ret : f64
252+
}
253+
254+
// CHECK-LABEL: func @powf_func_neghalf
255+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
256+
func.func @powf_func_neghalf(%a: f64) -> f64{
257+
// CHECK: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64
258+
// CHECK: %[[SQRT:.+]] = math.sqrt %[[ARG0]] : f64
259+
// CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[SQRT]] : f64
260+
// CHECK: return %[[DIV]] : f64
261+
%b = arith.constant -0.5 : f64
262+
%ret = math.powf %a, %b : f64
263+
return %ret : f64
264+
}
265+
266+
// CHECK-LABEL: func @powf_func_two
267+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
268+
func.func @powf_func_two(%a: f64) -> f64{
269+
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
270+
// CHECK: return %[[MUL]] : f64
271+
%b = arith.constant 2.0 : f64
272+
%ret = math.powf %a, %b : f64
273+
return %ret : f64
274+
}
275+
276+
// CHECK-LABEL: func @powf_func_negtwo
277+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
278+
func.func @powf_func_negtwo(%a: f64) -> f64{
279+
// CHECK-DAG: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
280+
// CHECK-DAG: %[[CSTONE:.+]] = arith.constant 1.000000e+00 : f64
281+
// CHECK: %[[DIV:.+]] = arith.divf %[[CSTONE]], %[[MUL]] : f64
282+
// CHECK: return %[[DIV]] : f64
283+
%b = arith.constant -2.0 : f64
224284
%ret = math.powf %a, %b : f64
225285
return %ret : f64
226286
}
@@ -602,26 +662,11 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
602662
return %2 : tensor<8xf32>
603663
}
604664
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
605-
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
606-
// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
607-
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
608-
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
609-
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
610-
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
611-
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
612-
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
613-
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
614-
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
615-
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
616-
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
617-
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
618-
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
619-
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
620-
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
621-
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
622-
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
623-
// CHECK: return %[[SEL1]] : tensor<8xf32>
624-
665+
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
666+
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
667+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
668+
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
669+
// CHECK: return %[[EXP]]
625670
// -----
626671

627672
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -630,25 +675,11 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
630675
return %2 : f32
631676
}
632677
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
633-
// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
634-
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
635-
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
636-
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
637678
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
638-
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
639-
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
640-
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
641-
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
679+
// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
680+
// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
642681
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
643-
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
644-
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
645-
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
646-
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
647-
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
648-
// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
649-
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
650-
// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
651-
// CHECK: return %[[SEL1]] : f32
682+
// CHECK: return %[[EXP]] : f32
652683

653684
// -----
654685

mlir/test/mlir-runner/test-expand-math-approx.mlir

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -202,55 +202,62 @@ func.func @powf() {
202202
%a_p = arith.constant 2.0 : f64
203203
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
204204

205-
// CHECK-NEXT: -27
206-
%b = arith.constant -3.0 : f64
207-
%b_p = arith.constant 3.0 : f64
208-
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
209-
210205
// CHECK-NEXT: 2.343
211-
%c = arith.constant 2.343 : f64
212-
%c_p = arith.constant 1.000 : f64
213-
call @func_powff64(%c, %c_p) : (f64, f64) -> ()
206+
%b = arith.constant 2.343 : f64
207+
%b_p = arith.constant 1.000 : f64
208+
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
214209

215210
// CHECK-NEXT: 0.176171
216-
%d = arith.constant 4.25 : f64
217-
%d_p = arith.constant -1.2 : f64
218-
call @func_powff64(%d, %d_p) : (f64, f64) -> ()
211+
%c = arith.constant 4.25 : f64
212+
%c_p = arith.constant -1.2 : f64
213+
call @func_powff64(%c, %c_p) : (f64, f64) -> ()
219214

220215
// CHECK-NEXT: 1
221-
%e = arith.constant 4.385 : f64
222-
%e_p = arith.constant 0.00 : f64
223-
call @func_powff64(%e, %e_p) : (f64, f64) -> ()
216+
%d = arith.constant 4.385 : f64
217+
%d_p = arith.constant 0.00 : f64
218+
call @func_powff64(%d, %d_p) : (f64, f64) -> ()
224219

225220
// CHECK-NEXT: 6.62637
226-
%f = arith.constant 4.835 : f64
227-
%f_p = arith.constant 1.2 : f64
228-
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
221+
%e = arith.constant 4.835 : f64
222+
%e_p = arith.constant 1.2 : f64
223+
call @func_powff64(%e, %e_p) : (f64, f64) -> ()
229224

230225
// CHECK-NEXT: nan
231-
%i = arith.constant 1.0 : f64
232-
%h = arith.constant 0x7fffffffffffffff : f64
233-
call @func_powff64(%i, %h) : (f64, f64) -> ()
226+
%f = arith.constant 1.0 : f64
227+
%f_p = arith.constant 0x7fffffffffffffff : f64
228+
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
234229

235230
// CHECK-NEXT: inf
236-
%j = arith.constant 29385.0 : f64
237-
%j_p = arith.constant 23598.0 : f64
238-
call @func_powff64(%j, %j_p) : (f64, f64) -> ()
231+
%g = arith.constant 29385.0 : f64
232+
%g_p = arith.constant 23598.0 : f64
233+
call @func_powff64(%g, %g_p) : (f64, f64) -> ()
239234

240235
// CHECK-NEXT: -nan
241-
%k = arith.constant 1.0 : f64
242-
%k_p = arith.constant 0xfff0000001000000 : f64
243-
call @func_powff64(%k, %k_p) : (f64, f64) -> ()
236+
%h = arith.constant 1.0 : f64
237+
%h_p = arith.constant 0xfff0000001000000 : f64
238+
call @func_powff64(%h, %h_p) : (f64, f64) -> ()
244239

245240
// CHECK-NEXT: -nan
246-
%l = arith.constant 1.0 : f32
247-
%l_p = arith.constant 0xffffffff : f32
248-
call @func_powff32(%l, %l_p) : (f32, f32) -> ()
241+
%i = arith.constant 1.0 : f32
242+
%i_p = arith.constant 0xffffffff : f32
243+
call @func_powff32(%i, %i_p) : (f32, f32) -> ()
249244

250245
// CHECK-NEXT: 1
251-
%zero = arith.constant 0.0 : f32
252-
call @func_powff32(%zero, %zero) : (f32, f32) -> ()
246+
%j = arith.constant 0.000 : f32
247+
%j_r = math.powf %j, %j : f32
248+
vector.print %j_r : f32
253249

250+
// CHECK-NEXT: 4
251+
%k = arith.constant -2.0 : f32
252+
%k_p = arith.constant 2.0 : f32
253+
%k_r = math.powf %k, %k_p : f32
254+
vector.print %k_r : f32
255+
256+
// CHECK-NEXT: 0.25
257+
%l = arith.constant -2.0 : f32
258+
%l_p = arith.constant -2.0 : f32
259+
%l_r = math.powf %l, %l_p : f32
260+
vector.print %l_r : f32
254261
return
255262
}
256263

0 commit comments

Comments
 (0)