Skip to content

Commit 0301bf9

Browse files
authored
[MLIR] Lower math.powf(x, 3.0) to x * x * x. (#127256)
`math.powf(x, y)` never really supported negative values of `x`, but that was unclear (happened to work for some values of `y`) until #126338 was merged yesterday and lowered it to the usual `exp(y * log(x))` outside of a few special exponent values, such as y == 2.0` lowering to `x * x`. It turns out that code in the wild has been relying on `math.powf(x, y)` with negative `x` for some integral values of `y` for which a lowering to muls was intended: iree-org/iree#19996 This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and it is a more efficient lowering anyway. There needs to be a wider project to stop altogether using `powf` with negative `x`, use `math.fpowi` for that. Signed-off-by: Benoit Jacob <[email protected]>
1 parent 48c92dd commit 0301bf9

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
325325
auto &sem =
326326
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
327327
APFloat valueB(sem);
328+
auto mulf = [&](Value x, Value y) -> Value {
329+
return b.create<arith::MulFOp>(x, y);
330+
};
328331
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
329332
if (valueB.isZero()) {
330333
// a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
358361
}
359362
if (valueB.isExactlyValue(2.0)) {
360363
// a^2 -> a * a
361-
Value mul = b.create<arith::MulFOp>(operandA, operandA);
362-
rewriter.replaceOp(op, mul);
364+
rewriter.replaceOp(op, mulf(operandA, operandA));
363365
return success();
364366
}
365367
if (valueB.isExactlyValue(-2.0)) {
366368
// a^(-2) -> 1 / (a * a)
367-
Value mul = b.create<arith::MulFOp>(operandA, operandA);
368369
Value one =
369370
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
370-
Value div = b.create<arith::DivFOp>(one, mul);
371+
Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
371372
rewriter.replaceOp(op, div);
372373
return success();
373374
}
375+
if (valueB.isExactlyValue(3.0)) {
376+
rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
377+
return success();
378+
}
374379
}
375380

376381
Value logA = b.create<math::LogOp>(operandA);

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
285285
return %ret : f64
286286
}
287287

288+
// CHECK-LABEL: func @powf_func_three
289+
// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
290+
func.func @powf_func_three(%a: f64) -> f64{
291+
// CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
292+
// CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
293+
// CHECK: return %[[MUL2]] : f64
294+
%b = arith.constant 3.0 : f64
295+
%ret = math.powf %a, %b : f64
296+
return %ret : f64
297+
}
298+
288299
// -----
289300

290301
// CHECK-LABEL: func.func @roundeven64

0 commit comments

Comments
 (0)