[mlir][math] powf(a, b): drop support when a < 0 #126338
Conversation
The current implementation of `convertPowfOp` requires computing `a * a`. Since max<fp16> ≈ 65,504, the intermediate square overflows to INF easily in fp16 (and, with `a` around 16, already in fp8). This PR therefore drops support for `a < 0`: handling a negative `a` is what requires the squaring, and its overhead is large and overflow-prone.

Related issue in IREE: iree-org/iree#15936
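A minimal standalone sketch (my own illustration, not part of the PR) of the overflow the description refers to, using `llvm::APFloat` with f16 semantics; the value 300 is an arbitrary example:

```cpp
// Sketch: the old expansion's intermediate a*a saturates to inf in f16
// even when a itself (and possibly a^b) is comfortably representable.
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  using llvm::APFloat;
  const auto &f16 = APFloat::IEEEhalf();
  APFloat a(f16, "300.0");                       // representable: max(f16) ~= 65504
  APFloat sq = a;
  sq.multiply(a, APFloat::rmNearestTiesToEven);  // 300 * 300 = 90000 > 65504
  llvm::outs() << (sq.isInfinity() ? "a*a overflows to inf\n" : "a*a is finite\n");
  return 0;
}
```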
@llvm/pr-subscribers-mlir-math @llvm/pr-subscribers-mlir

Author: Hyunsung Lee (ita9naiwa)

Changes

Related: #124402
Full diff: https://github.com/llvm/llvm-project/pull/126338.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 3dadf9474cf4f67..235ea38dd87d1c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -17,8 +17,13 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cmath>
using namespace mlir;
@@ -311,7 +316,92 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
return success();
}
-// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Convert Powf(float a, float b) for some special cases
+// where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0
+static LogicalResult convertSpecialPowfOp(math::PowFOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operandA = op.getOperand(0);
+ Value operandB = op.getOperand(1);
+ auto baseType = operandB.getType();
+
+ auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
+ .getFloatSemantics();
+
+ auto valueB = APFloat(sem);
+ if (!matchPattern(operandB, m_ConstantFloat(&valueB))) {
+ // Not a constant, return failure
+ return failure();
+ }
+ float floatValueB = valueB.convertToFloat();
+
+ if (floatValueB == 1.0f) {
+ // a^1 -> a
+ rewriter.replaceOp(op, operandA);
+ return success();
+ }
+
+ if (floatValueB == 0.0) {
+ // a^0 -> 1
+ Value one =
+ createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+ rewriter.replaceOp(op, one);
+ return success();
+ }
+
+ if (floatValueB == 0.5f) {
+ // a^(1/2) -> sqrt(a)
+ Value sqrt = b.create<math::SqrtOp>(operandA);
+ rewriter.replaceOp(op, sqrt);
+ return success();
+ }
+
+ if (floatValueB == -0.5f) {
+ // a^(-1/2) -> 1 / sqrt(a)
+ Value rsqrt = b.create<math::RsqrtOp>(operandA);
+ rewriter.replaceOp(op, rsqrt);
+ return success();
+ }
+
+ if (floatValueB == -1.0f) {
+ // a^(-1) -> 1 / a
+ Value one =
+ createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+ Value div = b.create<arith::DivFOp>(one, operandA);
+ rewriter.replaceOp(op, div);
+ return success();
+ }
+
+ // Check if the power is an integer
+ if (floatValueB != std::floor(floatValueB)) {
+ // We don't handle non-integer powers here, return failure
+ return failure();
+ }
+
+ auto sign = std::signbit(floatValueB) ? -1 : 1;
+ auto absIntValueB = std::abs(static_cast<int>(floatValueB));
+
+ auto cstOne =
+ createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
+ auto base = operandA;
+ if (sign == -1) {
+ base = b.create<arith::DivFOp>(cstOne, base);
+ }
+ auto current = base;
+ auto result = cstOne;
+ while (absIntValueB > 0) {
+ if (absIntValueB & 1) {
+ result = b.create<arith::MulFOp>(result, current);
+ }
+ current = b.create<arith::MulFOp>(current, current);
+ absIntValueB >>= 1;
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
+// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
+// Restricting a >= 0
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -319,21 +409,10 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
- Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
- Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
- Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
- Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
- Value logA = b.create<math::LogOp>(opType, opASquared);
- Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
+ Value logA = b.create<math::LogOp>(opType, operandA);
+ Value mult = b.create<arith::MulFOp>(opType, operandB, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
- Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
- Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
- Value negCheck =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
- Value oddPower =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
- Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
// First, we select between the exp value and the adjusted value for odd
// powers of negatives. Then, we ensure that one is produced if `b` is zero.
@@ -341,10 +420,9 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
// `exp(0 * ln(0)) = exp(0 *-inf) = exp(-nan) = -nan`.
Value zeroCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, operandB, zero);
- Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
- expResult);
- res = b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, res);
- rewriter.replaceOp(op, res);
+ Value finalResult =
+ b.create<arith::SelectOp>(op->getLoc(), zeroCheck, one, expResult);
+ rewriter.replaceOp(op, finalResult);
return success();
}
@@ -660,6 +738,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
}
void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+ patterns.add(convertSpecialPowfOp);
patterns.add(convertPowfOp);
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6055ed0504c84ca..5b443e9e8d4e78e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -202,25 +202,15 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
-func.func @powf_func(%a: f64, %b: f64) ->f64 {
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
// CHECK-DAG: [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[CST1:%.+]] = arith.constant 1.0
- // CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
- // CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
- // CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
- // CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
- // CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
- // CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
- // CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
- // CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
- // CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
- // CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
- // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
- // CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
- // CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
- // CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
- // CHECK-DAG: [[SEL1:%.+]] = arith.select [[CMPZERO]], [[CST1]], [[SEL]]
- // CHECK: return [[SEL1]]
+ // CHECK: [[LOGA:%.+]] = math.log [[ARG0]]
+ // CHECK: [[MULB:%.+]] = arith.mulf [[ARG1]], [[LOGA]]
+ // CHECK: [[EXP:%.+]] = math.exp [[MULB]]
+ // CHECK: [[CMPF:%.+]] = arith.cmpf oeq, [[ARG1]], [[CST0]]
+ // CHECK: [[SEL:%.+]] = arith.select [[CMPF]], [[CST1]], [[EXP]]
+ // CHECK: return [[SEL]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
@@ -602,26 +592,15 @@ func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> t
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
-// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
-// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
-// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
-// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
-// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
-// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
-// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
-// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
-// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK: return %[[SEL1]] : tensor<8xf32>
-
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
+// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : tensor<8xf32>
+// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : tensor<8xf32>
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
+// CHECK: return %[[SEL]]
// -----
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
@@ -630,25 +609,15 @@ func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
-// CHECK-DAG: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
-// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
-// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
-// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
+// CHECK: %[[LOGA:.*]] = math.log %[[ARG0]] : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TOFP]], %[[LOGA]] : f32
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
-// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
-// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
-// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
-// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
-// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
-// CHECK: %[[CMPZERO:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]]
-// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
-// CHECK: %[[SEL1:.+]] = arith.select %[[CMPZERO]], %[[CST1]], %[[SEL]]
-// CHECK: return %[[SEL1]] : f32
+// CHECK: %[[CMP:.*]] = arith.cmpf oeq, %[[TOFP]], %[[CST0]] : f32
+// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %[[CST1]], %[[EXP]] : f32
+// CHECK: return %[[SEL]] : f32
// -----
diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir
index 106b48a2daea2e3..d1916c28878b97a 100644
--- a/mlir/test/mlir-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir
@@ -202,11 +202,6 @@ func.func @powf() {
%a_p = arith.constant 2.0 : f64
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
- // CHECK-NEXT: -27
- %b = arith.constant -3.0 : f64
- %b_p = arith.constant 3.0 : f64
- call @func_powff64(%b, %b_p) : (f64, f64) -> ()
-
// CHECK-NEXT: 2.343
%c = arith.constant 2.343 : f64
%c_p = arith.constant 1.000 : f64
I want this looked at before I start changing lit test cases.
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
// Restricting a >= 0
Where is this actually checked? This seems to be expanding under this assumption, but it is always applied. Is this a new assumption on the op that should be documented?
Sorry! I forgot to describe the PR in more detail.

- I believe it should be documented. In general, `powf(a, b)` where `a < 0` yields NaN, and (as far as I know) we aren't able to check that at runtime.
- This transform should be applied only for some small values of `b` (e.g., when `abs(b) < 16`); see the loop below:
while (absIntValueB > 0) {
if (absIntValueB & 1) {
result = b.create<arith::MulFOp>(result, current);
}
current = b.create<arith::MulFOp>(current, current);
absIntValueB >>= 1;
}
rewriter.replaceOp(op, result);
return success();
The heuristic threshold for `b` has not been determined yet. This last case can be dropped if it's not necessary.
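(Editorial note, not from the comment: the loop quoted above is binary exponentiation, so even with `abs(b) < 16` it costs only a handful of multiplies, roughly 2·log2(b). A scalar C++ model of the same computation, for illustration only:)

```cpp
// Sketch: scalar equivalent of the square-and-multiply loop quoted above.
double powByInt(double base, unsigned exp) {
  double result = 1.0, current = base;
  while (exp > 0) {
    if (exp & 1)          // this bit of the exponent is set: fold in the current power
      result *= current;
    current *= current;   // square for the next bit
    exp >>= 1;
  }
  return result;
}
```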
One problem is that there are some special use cases where the variable `a < 0` but the constant `b` is some multiple of 2.

cc @hanhanW
Just drop the comment `// Restricting a >= 0` here.

Mathematically, the power operation `a^b` is well-defined in two separate (though overlapping) cases:

- When `a > 0`. In that case, `a^b` is defined as `exp(b * ln(a))`.
- When `b` is an integer. In that case, `a^b` is defined as `a * ... * a` (`b` times), or the reciprocal of that if `b` is negative.

These two definitions agree in the intersection of these two cases.

Because "power" has inherently that two-mode definition, the MLIR op `powf` should have been specified from the start to implement one of these two modes only. Obviously it should have been `a > 0`.

I believe that it is still time to clarify that. We have observed recently that some rewrite patterns for `powf` ops have been broken outside of the case `a > 0`, suggesting that no one was relying on that.

But that discussion doesn't need to be conflated into this PR, because this PR implements rewrites that are either agnostic as to which case we are in (e.g. the case of `pow(a, 2.0)`) or that are explicitly not applying to the other case anyway (e.g. the case of `pow(a, 0.5)`).
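A concrete illustration (editorial example, not from the comment) of how the two modes diverge once `a < 0`:

$$(-2)^3 = (-2)\cdot(-2)\cdot(-2) = -8 \qquad \text{(integer-exponent mode)}$$

$$\exp\bigl(3\ln(-2)\bigr) = \exp(3\cdot\mathrm{NaN}) = \mathrm{NaN} \qquad \text{(exp/log mode; ln(-2) is NaN in floating point)}$$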
// Check if the power is an integer
if (floatValueB != std::floor(floatValueB)) {
Don't try to handle arbitrary integer `x`. Just handle the special value `2.0`, and maybe also 3.0 and 4.0 if you want, but that's it. If someone really needs a larger integral exponent to be matched, we can always expand this pattern later.
+1, let's handle a few cases for now and document it in the function comment.
I kept only the `|b| = 2.0` case and removed the other integer cases.
@@ -660,6 +703,7 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
}

void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertSpecialPowfOp);
Question for MLIR experts: here, we want the `convertSpecialPowfOp` pattern to have precedence over the `convertPowfOp` pattern. Is that ensured by it being added first here? If not, do we need to merge these two patterns to ensure ordering?
patterns.add(convertSpecialPowfOp, /*benefit=*/ 2);
This would explicitly give the order we want.
I'm not very sure since I didn't check the code, but adding the patterns in this order makes `convertSpecialPowfOp` run first:

patterns.add(convertSpecialPowfOp);
patterns.add(convertPowfOp);
@hanhanW ?
> I'm not very sure since I didn't check the code, but adding patterns in this order make convertSpecialPowfOp run first.

This is correct, but I think it is not documented. IMO, we prefer using `benefit` to prioritize the patterns.
Updated to give explicit benefit=2!
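For reference, the registration presumably ends up looking something like this sketch (using the two-argument `add` overload with an explicit `PatternBenefit`, as suggested above; the final code may differ):

```cpp
void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
  // Higher benefit: try the constant-exponent special cases before the
  // generic exp(b * ln(a)) expansion.
  patterns.add(convertSpecialPowfOp, /*benefit=*/2);
  patterns.add(convertPowfOp);
}
```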
✅ With the latest revision this PR passed the C/C++ code formatter.
auto opType = operandA.getType();
auto baseType = operandB.getType();

auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
Here, since `math.powf` requires float arguments (I just checked MathOps.td), the cast really shouldn't ever fail, so I think you can simply use `cast` instead of `dyn_cast`. You weren't checking for a null return value from `dyn_cast` anyway.
Done!
auto opType = operandA.getType();
auto baseType = operandB.getType();
Can you rename `opType` to `typeA`, and `baseType` to `typeB`?
Done!
// Not a constant, return failure
return failure();
}
float floatValueB = valueB.convertToFloat();
Avoid converting to C++ `float`, as the actual type could have higher precision, so this conversion would be rounding to lesser precision and could end up enabling an incorrect rewrite. E.g. if the type is `f64`, then the rewrite from `powf(a, 1.0 + 1.0e-10)` into `a` would be incorrect as, in `f64`, `1.0 + 1.0e-10 != 1.0`, but the rounded `floatValueB` is exactly `1.0`.

Can you keep the whole logic in this function in APFloat?
Done, but it might be a bit verbose, so could you check it please?
@hanhanW, I am not familiar with APFloat myself; can you review that aspect?
I'm not very familiar with it either. After skimming through the doc, I think we can use the `isExactlyValue` method. I think we do not have precision issues for the `0`, `±1`, `±0.5`, `±2` values.
llvm-project/llvm/include/llvm/ADT/APFloat.h
Lines 1420 to 1433 in bee9664

/// We don't rely on operator== working on double values, as
/// it returns true for things that are clearly not equal, like -0.0 and 0.0.
/// As such, this method can be used to do an exact bit-for-bit comparison of
/// two floating point values.
///
/// We leave the version with the double argument here because it's just so
/// convenient to write "2.0" and the like. Without this function we'd
/// have to duplicate its logic everywhere it's called.
bool isExactlyValue(double V) const {
  bool ignored;
  APFloat Tmp(V);
  Tmp.convert(getSemantics(), APFloat::rmNearestTiesToEven, &ignored);
  return bitwiseIsEqual(Tmp);
}
Fixed!
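For reference, a minimal sketch (my reconstruction, not necessarily the exact code that landed) of keeping the comparisons in APFloat via `isExactlyValue`, following the structure of `convertSpecialPowfOp` in the diff above:

```cpp
// Sketch: compare the constant exponent with APFloat::isExactlyValue instead
// of rounding it through a C++ float first.
auto &sem =
    cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
if (!matchPattern(operandB, m_ConstantFloat(&valueB)))
  return failure(); // the exponent is not a constant

if (valueB.isExactlyValue(1.0)) { // a^1 -> a
  rewriter.replaceOp(op, operandA);
  return success();
}
if (valueB.isExactlyValue(0.5)) { // a^(1/2) -> sqrt(a)
  Value sqrt = b.create<math::SqrtOp>(operandA);
  rewriter.replaceOp(op, sqrt);
  return success();
}
if (valueB.isExactlyValue(-0.5)) { // a^(-1/2) -> rsqrt(a)
  Value rsqrt = b.create<math::RsqrtOp>(operandA);
  rewriter.replaceOp(op, rsqrt);
  return success();
}
// ... remaining cases (0.0, -1.0, +-2.0) handled analogously.
```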
I also think it might be great to merge these two patterns.

@ita9naiwa, yes, good idea. That would be simpler.
At least we need all the special cases to be tested in this file.
My apologies that this PR got verbose. I added the appropriate tests and everything runs well!
I added appropriate tests, and I think it's now ready to land!
LGTM, just one question about `getFloatSemantics`.
auto &sem =
    cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
APFloat valueB(sem);
I'm not quite sure, but do we really need to get the semantics? I feel that the matcher already does that work for you.
https://github.com/ita9naiwa/llvm-project/blob/d5ee522a217359264d35516a5b506149b5d1ab68/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp#L286-L288
Other parts of the same file explicitly use `sem`, so I followed that. It would work without it; should we remove it?
// CHECK: [[SQRT:%.+]] = math.sqrt [[ARG0]] : f64
// CHECK: [[DIV:%.+]] = arith.divf [[CSTONE]], [[SQRT]] : f64
I don't understand how this test succeeds: the C++ code in the rewrite is correctly generating a `math.rsqrt`, which is indeed better than a `math.sqrt` + `arith.divf`. So how is this test, which requires the latter, passing?
The powf is decomposed into rsqrt, and then the rsqrt is decomposed further into these ops. I don't understand where that second expansion happens, so I couldn't capture only the first transform.
ah, I see. Thanks for the explanation.
// CHECK: [[LOGA:%.+]] = math.log [[ARG0]] : f64
// CHECK: [[MUL:%.+]] = arith.mulf [[ARG1]], [[LOGA]] : f64
// CHECK: [[EXP:%.+]] = math.exp [[MUL]] : f64
// CHECK: return [[EXP]] : f64
Let's follow the `%[[XXX:.+]]` style because it is more common in the MLIR codebase. One of the benefits is that we do not need to escape `[` in the `[%[[XXX]]]` case.
I followed other tests, which use the `[[VAR:%.+]]` style; I will fix it.
// CHECK-LABEL: func @ceilf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
// CHECK-NEXT: return [[ADDF]]
%ret = math.ceil %a : f64
return %ret : f64
}
Fixed.
I fixed it. Thanks, now I see how the FileCheck tests work.
`math.powf(x, y)` never really supported negative values of `x`, but that was unclear (it happened to work for some values of `y`) until #126338 was merged yesterday and lowered it to the usual `exp(y * log(x))` outside of a few special exponent values, such as `y == 2.0` lowering to `x * x`. It turns out that code in the wild has been relying on `math.powf(x, y)` with negative `x` for some integral values of `y` for which a lowering to muls was intended: iree-org/iree#19996. This PR adds such a lowering for `y == 3.0`. It "fixes" such cases, and it is a more efficient lowering anyway. There needs to be a wider project to stop using `powf` with negative `x` altogether; use `math.fpowi` for that.

Signed-off-by: Benoit Jacob <[email protected]>
Related: llvm#124402

- change inefficient implementation of `powf(a, b)` to handle `a < 0` case
- thus drop `a < 0` case support

However, some special cases are being used, such as:

- `a < 0` and `b = 0, b = 0.5, b = 1 or b = 2`
- convert those special cases into simpler ops.
Pure floating point pow operations no longer support negative base values (see <llvm/llvm-project#126338>), but many models coming from ONNX use floating point representations of integers as the exponent. This change:

1. matches on constant rank-0 exponents and converts them to scalar constants.
2. matches on constant floating-point scalar exponents and converts them to ints if possible.
3. lowers `Tensor(float)^int` cases to `math.fpowi`.

Addresses some remaining test failures related to <iree-org/iree#19996>.
Related: #124402

- change inefficient implementation of `powf(a, b)` to handle `a < 0` case
- thus drop `a < 0` case support

However, some special cases are being used, such as:

- `a < 0` and `b = 0, b = 1 or b % 2 == 0`