-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Lower math.powf(x, 3.0)
to x * x * x
.
#127256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Benoit Jacob <[email protected]>
@llvm/pr-subscribers-mlir Author: Benoit Jacob (bjacob) Changes
It turns out that code in the wild has been relying on This PR adds such a lowering for There needs to be a wider project to stop altogether using Full diff: https://github.com/llvm/llvm-project/pull/127256.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
auto &sem =
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
+ auto mulf = [&](Value x, Value y) -> Value {
+ return b.create<arith::MulFOp>(x, y);
+ };
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
if (valueB.isZero()) {
// a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
}
if (valueB.isExactlyValue(2.0)) {
// a^2 -> a * a
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
- rewriter.replaceOp(op, mul);
+ rewriter.replaceOp(op, mulf(operandA, operandA));
return success();
}
if (valueB.isExactlyValue(-2.0)) {
// a^(-2) -> 1 / (a * a)
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
- Value div = b.create<arith::DivFOp>(one, mul);
+ Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
rewriter.replaceOp(op, div);
return success();
}
+ if (valueB.isExactlyValue(3.0)) {
+ rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+ return success();
+ }
}
Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
return %ret : f64
}
+// CHECK-LABEL: func @powf_func_three
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+ // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+ // CHECK: return %[[MUL2]] : f64
+ %b = arith.constant 3.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
// -----
// CHECK-LABEL: func.func @roundeven64
|
@llvm/pr-subscribers-mlir-math Author: Benoit Jacob (bjacob) Changes
It turns out that code in the wild has been relying on This PR adds such a lowering for There needs to be a wider project to stop altogether using Full diff: https://github.com/llvm/llvm-project/pull/127256.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index d7953719d44b5..23356d752146d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -325,6 +325,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
auto &sem =
cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
+ auto mulf = [&](Value x, Value y) -> Value {
+ return b.create<arith::MulFOp>(x, y);
+ };
if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
if (valueB.isZero()) {
// a^0 -> 1
@@ -358,19 +361,21 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
}
if (valueB.isExactlyValue(2.0)) {
// a^2 -> a * a
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
- rewriter.replaceOp(op, mul);
+ rewriter.replaceOp(op, mulf(operandA, operandA));
return success();
}
if (valueB.isExactlyValue(-2.0)) {
// a^(-2) -> 1 / (a * a)
- Value mul = b.create<arith::MulFOp>(operandA, operandA);
Value one =
createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
- Value div = b.create<arith::DivFOp>(one, mul);
+ Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
rewriter.replaceOp(op, div);
return success();
}
+ if (valueB.isExactlyValue(3.0)) {
+ rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
+ return success();
+ }
}
Value logA = b.create<math::LogOp>(operandA);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index f39d1a7a6dc50..1fdfb854325b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -285,6 +285,17 @@ func.func @powf_func_negtwo(%a: f64) -> f64{
return %ret : f64
}
+// CHECK-LABEL: func @powf_func_three
+// CHECK-SAME: (%[[ARG0:.+]]: f64) -> f64
+func.func @powf_func_three(%a: f64) -> f64{
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[ARG0]], %[[ARG0]] : f64
+ // CHECK: %[[MUL2:.+]] = arith.mulf %[[MUL]], %[[ARG0]] : f64
+ // CHECK: return %[[MUL2]] : f64
+ %b = arith.constant 3.0 : f64
+ %ret = math.powf %a, %b : f64
+ return %ret : f64
+}
+
// -----
// CHECK-LABEL: func.func @roundeven64
|
`math.powf(x, y)` never really supported negative values of `x`, but that was unclear (happened to work for some values of `y`) until llvm#126338 was merged yesterday and lowered it to the usual `exp(y * log(x))` outside of a few special exponent values, such as y == 2.0` lowering to `x * x`. It turns out that code in the wild has been relying on `math.powf(x, y)` with negative `x` for some integral values of `y` for which a lowering to muls was intended: iree-org/iree#19996 This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and it is a more efficient lowering anyway. There needs to be a wider project to stop altogether using `powf` with negative `x`, use `math.fpowi` for that. Signed-off-by: Benoit Jacob <[email protected]>
math.powf(x, y)
never really supported negative values ofx
, but that was unclear (happened to work for some values ofy
) until #126338 was merged yesterday and lowered it to the usualexp(y * log(x))
outside of a few special exponent values, such asy == 2.0
lowering tox * x
.It turns out that code in the wild has been relying on
math.powf(x, y)
with negativex
for some integral values ofy
for which a lowering to muls was intended: iree-org/iree#19996This PR adds such a lowering for
y == 3.0
. It "fixes" such cases, and it is a more efficient lowering anyway.There needs to be a wider project to stop altogether using
powf
with negativex
, usemath.fpowi
for that.