-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math]Update convertPowfOp
ExpandPatterns.cpp
#124402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
convertPowfOp
ExpandPatterns.cpp
convertPowfOp
ExpandPatterns.cpp
@llvm/pr-subscribers-mlir-math @llvm/pr-subscribers-mlir Author: Hyunsung Lee (ita9naiwa) ChangesThe current implementation of Instead, take the sign of or, instead of this approach, we also proceed only casting to a wider type(e.g., fp32, or fp64)
Full diff: https://github.com/llvm/llvm-project/pull/124402.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f6..314b5b30202064 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -317,34 +317,43 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
+
+ // Constants
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
- Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
- Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
+ Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
+
+ // Compute |a| (absolute value of operandA)
+ Value absA = b.create<math::AbsFOp>(opType, operandA);
+
+ // Compute sign(a) as -1.0 if a < 0, else 1.0
+ Value isNegative = b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
+ Value signA = b.create<arith::SelectOp>(op->getLoc(), isNegative, negOne, one);
- Value logA = b.create<math::LogOp>(opType, opASquared);
- Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+ // Compute ln(|a|)
+ Value logA = b.create<math::LogOp>(opType, absA);
+
+ // Compute b * ln(|a|)
+ Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
+
+ // Compute exp(b * ln(|a|))
Value expResult = b.create<math::ExpOp>(opType, mult);
- Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
- Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value negCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value oddPower =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
+ Value logSign = b.create<math::LogOp>(opType, signA);
+ Value signMult = b.create<arith::MulFOp>(opType, operandB, logSign);
+ Value signPow = b.create<math::ExpOp>(opType, signMult);
+
+ Value resultWithSign = b.create<arith::MulFOp>(opType, expResult, signPow);
// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
// This corresponds to `libm` behavior, even for `0^0`. Without this check,
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
- Value zeroCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
- expResult);
- res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
- rewriter.replaceOp(op, res);
+ Value zeroCheck =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
+ Value finalResult =
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, resultWithSign);
+ rewriter.replaceOp(op, finalResult);
return success();
}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
I'm sorry, the basic logic is right I think, but I'm still not getting used to |
Hi, it's ready for review! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Logic LGTM, I'll let others (maybe @kuhar ) if this is the most efficient way to do this.
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(|a|)) | ||
// * sign(a)^b |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I expected math.powf(a, b)
to be restricted to the case of a > 0
. It is not currently mentioned in the documentation, but I believe that that is an omission.
The formula in this comment, a ^ b = e^(b * ln |a|) * sign(a)^b
, is applicable in two separate cases:
- It is correct whenever
a > 0
, in which case it boils down to the usuala ^ b = e^(b * ln |a|)
. - It is correct when
b
is an integer, in which casesign(a)^b
is well-defined even whensign(a) == -1
.
Outside of the above cases, that is, when a < 0
and b
is not integral, the problem is that a^b
would have to be a non-real complex number, and moreover, it would be determined only up to a complex factor depending on the precise value of b
. For example, when a = -1
and b = 0.5
, then a^b
would be +/- i
. When a = -1
and b = 0.25
, then a^b
would be determined only among the 4 values {1, i, -1, -i}
.
The code in this PR is actually implementing something different. It has to, since it can't produce complex numbers; that makes the above comment an imperfect description of what the code actually does.
What the code in this PR does is that it evaluates remf(b, 2.0)
. When that returns exactly remf(b, 2.0) == 0.0
, it returns just a ^ b = e^(b * ln |a|)
. Otherwise, when remf(b, 2.0) != 0.0
, it returns a ^ b = - e^(b * ln |a|)
.
The problems with that are:
- This is unreliable for large values of
b
. First becauseremf
may be inexact, and then, for even larger values ofb
or for narrower floating-point types, because whenb
is meant to encode a large integer, it may be rounded to the nearest representable value which may not be exactly integral. In either case,remf(b, 2.0)
may fail to be exactly0.0
in cases where it was expected to be. remf
is expensive, so this attempt at supportinga < 0
is a cost on all users ofmath.powf
. The branch-less implementation implies that the cost is paid even whena > 0
even though the result of the computation is effectively unused in that case.
Summary: My recommendation: restrict math.powf
to the case where a > 0
. Correspondingly simplify the docs, the above comment, and the lowering. The remf
goes away.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary: My recommendation: restrict math.powf to the case where a > 0. Correspondingly simplify the docs, the above comment, and the lowering. The remf goes away.
I agree with that, but restricting math.powf
to a > 0 introduces a potential compatibility issue.
What do you think is the best way to move forward?
@bjacob
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My guess is that it was always meant to be restricted to a>0 and it is only an omission in the docs. So my recommendation is to just go ahead with that and see what breaks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the previous implementation actually does (exp((b / 2) * log(a^2))
and I think this trick is introduced to handle a < 0
case,
But yes, I believe that math.powf
should support only the case where a>0
and a == 0
.
Thanks for the thoughtful review @bjacob !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for digging that history. I have no idea what the previous implementation tried to achieve by this scaling by a factor of 2 (EDIT, now that I read the PR description, it really thought that it was achieving a^b for a<0
!), but here is what it actually achieved:
First, notice that log(a^2) == log( |a| ^2) == 2 * log ( |a| )
.
Therefore, exp((b / 2) * log(a^2)) == exp ( b / 2 * 2 * log ( |a| )
.
Therefore, exp((b / 2) * log(a^2)) == exp ( b * log ( |a| )
.
Therefore, exp((b / 2) * log(a^2)) == |a| ^ b
.
Thus, the old implementation was evaluating powf( |a| , b )
instead of powf(a, b)
.
Apparently, no one noticed :-)
I believe that the reason why no one noticed is that no one was using a < 0
.
So I really think that you can go ahead and restrict support to a > 0
officially.
I updated upon your review! @bjacob |
could someone merge this? I have no write permission. |
Thanks for working on this, and thanks for reviews from @bjacob and @Groverkss ! |
)" This reverts commit 3a33775.
…6063) Reverts #124402 It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue.
….cpp`" (#126063) Reverts llvm/llvm-project#124402 It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue.
The current implementation of `convertPowfOp` requires a calculation of `a * a` but, max\<fp16\> ~= 65,504, and if `a` is about 16, it will overflow so get INF in fp8 or fp16 easily. Remove support when `a < 0`. Overhead of handling negative value of `a` is large and easy to overflow; - related issue in iree: iree-org/iree#15936
…m#126063) Reverts llvm#124402 It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue.
Related: #124402 - change inefficient implementation of `powf(a, b)` to handle `a < 0` case - thus drop `a < 0` case support However, some special cases are being used such as: - `a < 0` and `b = 0, b = 0.5, b = 1 or b = 2` - convert those special cases into simpler ops.
Related: llvm#124402 - change inefficient implementation of `powf(a, b)` to handle `a < 0` case - thus drop `a < 0` case support However, some special cases are being used such as: - `a < 0` and `b = 0, b = 0.5, b = 1 or b = 2` - convert those special cases into simpler ops.
Related: llvm#124402 - change inefficient implementation of `powf(a, b)` to handle `a < 0` case - thus drop `a < 0` case support However, some special cases are being used such as: - `a < 0` and `b = 0, b = 0.5, b = 1 or b = 2` - convert those special cases into simpler ops.
…6063) Reverts llvm/llvm-project#124402 It breaks an integration test in downstream project (i.e., IREE), which produces NANs. Talked to the author @ita9naiwa, and we agree to reland the PR after we find the issue. patch.cherry: true patch.metadata.original_sha: c9d0a46 patch.platforms: chromiumos patch.version_range.from: 563468 patch.version_range.until: 564416 Change-Id: I33ef323aac2e00caf2ca9ef20b12e0003e45632e
The current implementation of
convertPowfOp
requires a calculation ofa * a
but, max<fp16> ~= 65,504, and ifa
is about 16, it will overflow so get INF in fp8 or fp16 easily.Remove support when
a < 0
. Overhead of handling negative value ofa
is large and easy to overflow;The polynomial approximation for f16 math.powf generates NAN and INF iree-org/iree#15936