
[mlir][math]Update convertPowfOp ExpandPatterns.cpp #124402


Merged (2 commits) on Jan 29, 2025

Conversation

@ita9naiwa (Contributor) commented Jan 25, 2025

The current implementation of convertPowfOp computes a * a. Since max<fp16> ~= 65,504, even for a of about 16 this product can overflow, easily yielding INF in fp8 or fp16.

Remove support for a < 0. Handling negative values of a adds significant overhead and makes overflow more likely.
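
The overflow described above can be reproduced with a small scalar sketch. It is only an illustration under assumptions: it uses `float` instead of fp16 or fp8 (standard C++ has no portable half type), so the threshold is different but the mechanism is the same, and `powfViaSquare` and `powfViaLog` are hypothetical names rather than functions from the PR.

```cpp
#include <cassert>
#include <cmath>

// Old-style expansion: exp((b / 2) * ln(a * a)).
// The intermediate square is the risky step.
float powfViaSquare(float a, float b) {
  float aSquared = a * a;  // becomes +inf once a * a exceeds the float range
  return std::exp((b / 2.0f) * std::log(aSquared));
}

// Simplified expansion: exp(b * ln(a)). Only meaningful for a > 0.
float powfViaLog(float a, float b) {
  return std::exp(b * std::log(a));
}
```

With `a = 1e20f` and `b = 0.5f`, `powfViaSquare` returns infinity because `a * a` no longer fits in a `float`, while `powfViaLog` returns roughly `1e10`.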

@ita9naiwa changed the title from "Update convertPowfOp ExpandPatterns.cpp" to "[mlir][math]Update convertPowfOp ExpandPatterns.cpp" on Jan 25, 2025
@llvmbot (Member) commented Jan 25, 2025

@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Hyunsung Lee (ita9naiwa)

Changes

The current implementation of convertPowfOp computes a * a. Since max<fp16> ~= 65,504, even for a of about 16 this product can overflow, easily yielding INF in fp8 or fp16.

Instead, take the sign of a and expand it.

Alternatively, instead of this approach, we could simply cast to a wider type (e.g., fp32 or fp64).


Full diff: https://github.com/llvm/llvm-project/pull/124402.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+27-18)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f6..314b5b30202064 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,34 +317,43 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   Value operandA = op.getOperand(0);
   Value operandB = op.getOperand(1);
   Type opType = operandA.getType();
+
+  // Constants
   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
-  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
-  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
-  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
+  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+
+  // Compute |a| (absolute value of operandA)
+  Value absA = b.create<math::AbsFOp>(opType, operandA);
+
+  // Compute sign(a) as -1.0 if a < 0, else 1.0
+  Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+  Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
 
-  Value logA = b.create<math::LogOp>(opType, opASquared);
-  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+  // Compute ln(|a|)
+  Value logA = b.create<math::LogOp>(opType, absA);
+
+  // Compute b * ln(|a|)
+  Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+
+  // Compute exp(b * ln(|a|))
   Value expResult = b.create<math::ExpOp>(opType, mult);
-  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
-  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
-  Value negCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
-  Value oddPower =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
-  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+  Value logSign = b.create<math::LogOp>(opType, signA);
+  Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
+  Value signPow = b.create<math::ExpOp>(opType, signMult);
+
+  Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
 
   // First, we select between the exp value and the adjusted value for odd
   // powers of negatives. Then, we ensure that one is produced if `b` is zero.
   // This corresponds to `libm` behavior, even for `0^0`. Without this check,
   // `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
-  Value zeroCheck =
-      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
-  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
-                                        expResult);
-  res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
-  rewriter.replaceOp(op, res);
+  Value zeroCheck = 
+    b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+  Value finalResult = 
+    b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+  rewriter.replaceOp(op, finalResult);
   return success();
 }
 

github-actions bot commented Jan 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ita9naiwa ita9naiwa marked this pull request as draft January 25, 2025 11:28
@ita9naiwa ita9naiwa marked this pull request as ready for review January 25, 2025 12:37
@ita9naiwa (Contributor Author)

I'm sorry; I think the basic logic is right, but I'm still getting used to lit and the other LLVM/MLIR test suites. Expect 1-2 days for me to finish the test cases.

@ita9naiwa (Contributor Author)

Hi, it's ready for review!

@Groverkss (Member) left a comment

Logic LGTM. I'll let others (maybe @kuhar) judge whether this is the most efficient way to do this.

Comment on lines 314 to 315
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(|a|))
// * sign(a)^b
@bjacob (Contributor) commented Jan 28, 2025

I expected math.powf(a, b) to be restricted to the case of a > 0. It is not currently mentioned in the documentation, but I believe that that is an omission.

The formula in this comment, a ^ b = e^(b * ln |a|) * sign(a)^b, is applicable in two separate cases:

  1. It is correct whenever a > 0, in which case it boils down to the usual a ^ b = e^(b * ln |a|).
  2. It is correct when b is an integer, in which case sign(a)^b is well-defined even when sign(a) == -1.

Outside of the above cases, that is, when a < 0 and b is not integral, the problem is that a^b would have to be a non-real complex number, and moreover, it would be determined only up to a complex factor depending on the precise value of b. For example, when a = -1 and b = 0.5, then a^b would be +/- i. When a = -1 and b = 0.25, then a^b would be determined only up to a factor among the 4 values {1, i, -1, -i}.
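
As a quick illustration of the point about negative bases, the sketch below compares the real-valued and complex-valued `std::pow` from the C++ standard library. It is not part of the PR; `realPow` and `complexPow` are hypothetical wrapper names used only for this example.

```cpp
#include <cassert>
#include <cmath>
#include <complex>

// Real-valued pow: a negative base with a non-integral exponent
// has no real result, so the library reports NaN.
double realPow(double a, double b) { return std::pow(a, b); }

// Complex-valued pow: returns the principal value, which is only
// one of several mathematically valid choices.
std::complex<double> complexPow(double a, double b) {
  return std::pow(std::complex<double>(a, 0.0), b);
}
```

`realPow(-1.0, 0.5)` is NaN, while `complexPow(-1.0, 0.5)` is approximately `i`, which is one of the two candidates +/- i mentioned above.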

The code in this PR is actually implementing something different. It has to, since it can't produce complex numbers; that makes the above comment an imperfect description of what the code actually does.

What the code in this PR does is evaluate remf(b, 2.0). When the result is exactly 0.0, it returns a ^ b = e^(b * ln |a|). Otherwise, when remf(b, 2.0) != 0.0, it returns a ^ b = -e^(b * ln |a|).

The problems with that are:

  1. This is unreliable for large values of b. First because remf may be inexact, and then, for even larger values of b or for narrower floating-point types, because when b is meant to encode a large integer, it may be rounded to the nearest representable value which may not be exactly integral. In either case, remf(b, 2.0) may fail to be exactly 0.0 in cases where it was expected to be.
  2. remf is expensive, so this attempt at supporting a < 0 is a cost on all users of math.powf. The branch-less implementation implies that the cost is paid even when a > 0 even though the result of the computation is effectively unused in that case.
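
To make the concerns above more concrete, here is a scalar sketch of the parity-based expansion as described in this comment. It assumes `std::fmod` on `float` is a fair stand-in for remf, and `looksEven` and `powfWithParity` are hypothetical names, not code from the PR.

```cpp
#include <cassert>
#include <cmath>

// Parity test as described: b counts as "even" iff fmod(b, 2) == 0.
bool looksEven(float b) { return std::fmod(b, 2.0f) == 0.0f; }

// Hypothetical scalar model: compute |a|^b and flip the sign
// when a is negative and b does not look even.
float powfWithParity(float a, float b) {
  float magnitude = std::exp(b * std::log(std::fabs(a)));
  bool flip = (a < 0.0f) && !looksEven(b);
  return flip ? -magnitude : magnitude;
}
```

Two observations follow from this model. `looksEven(16777217.0f)` is true, because 2^24 + 1 is not representable in `float` and the literal rounds to 2^24, so the parity of a large integer exponent can be lost before the test even runs. And `powfWithParity(-4.0f, 0.5f)` evaluates to roughly -2, although (-4)^0.5 has no real value.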

Summary: My recommendation: restrict math.powf to the case where a > 0. Correspondingly simplify the docs, the above comment, and the lowering. The remf goes away.

Contributor Author

> Summary: My recommendation: restrict math.powf to the case where a > 0. Correspondingly simplify the docs, the above comment, and the lowering. The remf goes away.

I agree with that, but restricting math.powf to a > 0 introduces a potential compatibility issue.

What do you think is the best way to move forward?
@bjacob

Contributor

My guess is that it was always meant to be restricted to a>0 and it is only an omission in the docs. So my recommendation is to just go ahead with that and see what breaks.

Contributor Author

The previous implementation actually computes exp((b / 2) * log(a^2)), and I think this trick was introduced to handle the a < 0 case.

But yes, I believe that math.powf should support only the cases a > 0 and a == 0.

Thanks for the thoughtful review @bjacob !

@bjacob (Contributor) commented Jan 29, 2025

Thanks for digging up that history. I had no idea what the previous implementation was trying to achieve with this scaling by a factor of 2 (EDIT: now that I have read the PR description, it really thought that it was achieving a^b for a < 0!), but here is what it actually achieved:

First, notice that log(a^2) == log(|a|^2) == 2 * log(|a|).

Therefore, exp((b / 2) * log(a^2)) == exp((b / 2) * 2 * log(|a|)).

Therefore, exp((b / 2) * log(a^2)) == exp(b * log(|a|)).

Therefore, exp((b / 2) * log(a^2)) == |a|^b.

Thus, the old implementation was evaluating powf( |a| , b ) instead of powf(a, b).
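
The derivation above can be checked numerically with a minimal sketch. It models only the exp/log core of the previous lowering in `double`, not the surrounding select logic, and `oldCore` is a hypothetical name introduced for this example.

```cpp
#include <cassert>
#include <cmath>

// The exp/log core of the previous lowering: exp((b / 2) * ln(a^2)).
// Squaring discards the sign of a, so this evaluates |a|^b.
double oldCore(double a, double b) {
  return std::exp((b / 2.0) * std::log(a * a));
}
```

`oldCore(-2.0, 3.0)` evaluates to approximately 8, that is |-2|^3, rather than the -8 that powf(-2, 3) would give.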

Apparently, no one noticed :-)

I believe that the reason why no one noticed is that no one was using a < 0.

So I really think that you can go ahead and restrict support to a > 0 officially.

@ita9naiwa (Contributor Author)

I've updated the PR based on your review, @bjacob!

@ita9naiwa (Contributor Author)

Could someone merge this? I don't have write permission.

@bjacob bjacob merged commit 3a33775 into llvm:main Jan 29, 2025
8 checks passed
@hanhanW (Contributor) commented Jan 29, 2025

Thanks for working on this, and thanks for the reviews from @bjacob and @Groverkss!

hanhanW added a commit to iree-org/llvm-project that referenced this pull request Feb 6, 2025
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Feb 6, 2025
hanhanW added a commit that referenced this pull request Feb 6, 2025
hanhanW added a commit that referenced this pull request Feb 6, 2025
…6063)

Reverts #124402

It breaks an integration test in downstream project (i.e., IREE), which
produces NANs. Talked to the author @ita9naiwa, and we agree to reland
the PR after we find the issue.
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Feb 6, 2025
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 6, 2025
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Feb 6, 2025
ita9naiwa added a commit to ita9naiwa/llvm-project that referenced this pull request Feb 8, 2025
The current implementation of `convertPowfOp` computes `a * a`. Since max<fp16> ~= 65,504, even for `a` of about 16 this product can overflow, easily yielding INF in fp8 or fp16.

Remove support for `a < 0`. Handling negative values of `a` adds significant overhead and makes overflow more likely.

- related issue in iree:
iree-org/iree#15936
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
hanhanW pushed a commit that referenced this pull request Feb 13, 2025
Related: #124402

- change the implementation of `powf(a, b)`, which was inefficient because it handled the `a < 0` case
  - thus drop support for the `a < 0` case

However, some special cases are in use, such as:
  - `a < 0` with `b = 0`, `b = 0.5`, `b = 1`, or `b = 2`
  - convert those special cases into simpler ops.
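
One possible reading of the special-case handling in this commit message is sketched below as a scalar function. This is an assumption for illustration only: the message does not say which ops the special cases are converted to, so the choices here (constant 1, identity, multiplication, square root) and the name `powfWithSpecialCases` are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar model: peel off exponents that have cheap direct
// forms, then fall back to exp(b * ln(a)), which needs a > 0.
float powfWithSpecialCases(float a, float b) {
  if (b == 0.0f) return 1.0f;          // a^0 == 1
  if (b == 1.0f) return a;             // a^1 == a
  if (b == 2.0f) return a * a;         // a^2 is fine for negative a
  if (b == 0.5f) return std::sqrt(a);  // a^0.5 == sqrt(a), NaN for a < 0
  return std::exp(b * std::log(a));    // general case, NaN for a < 0
}
```

Under this model, `powfWithSpecialCases(-3.0f, 2.0f)` returns 9, whereas the general path `exp(2 * log(-3))` is NaN, which is consistent with the NaNs mentioned in the revert message.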
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
oneseer pushed a commit to oneseer/llvm that referenced this pull request May 24, 2025

patch.cherry: true
patch.metadata.original_sha: c9d0a46
patch.platforms: chromiumos
patch.version_range.from: 563468
patch.version_range.until: 564416

Change-Id: I33ef323aac2e00caf2ca9ef20b12e0003e45632e