-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math] Expand powfI operation for constant power operand. #87081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: Prashant Kumar (pashu123) Changes-- Convert Full diff: https://github.com/llvm/llvm-project/pull/87081.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 11b2c7a7afa2f7..e2c513047c77a5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandPowFPattern(RewritePatternSet &patterns);
+void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index e1ab9c905447b7..82947d064c9fff 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -202,6 +202,48 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
rewriter.replaceOp(op, ret);
return success();
}
+
+// Convert `math.fpowi` to a series of `arith.mulf` operations.
+// If the power is negative, we divide the result by 1.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operandA = op.getOperand(0);
+ Value operandB = op.getOperand(1);
+ Type opType = operandA.getType();
+ auto conOp =
+ mlir::dyn_cast<mlir::arith::ConstantOp>(operandB.getDefiningOp());
+
+ if (!conOp)
+ return failure();
+
+ auto iAttr = dyn_cast<mlir::SplatElementsAttr>(conOp.getValue());
+
+ if (!iAttr)
+ return failure();
+
+ int64_t power = iAttr.getSplatValue<int64_t>();
+ bool neg = power < 0;
+ int64_t absPower = std::abs(power);
+ Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+ Value res = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+
+ while (absPower > 0) {
+
+ if (absPower & 1)
+ res = b.create<arith::MulFOp>(opType, operandA, res);
+
+ absPower = absPower >> 1;
+ operandA = b.create<arith::MulFOp>(opType, operandA, operandA);
+ }
+
+ if (neg)
+ res = b.create<arith::DivFOp>(opType, one, res);
+
+ rewriter.replaceOp(op, res);
+ return success();
+}
+
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +559,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertPowfOp);
}
+void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
+ patterns.add(convertFPowIOp);
+}
+
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundOp);
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6326d3a71874b4..01dfe4783cfb69 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -511,3 +511,60 @@ func.func @roundeven16(%arg: f16) -> f16 {
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
// CHECK: return %[[COPYSIGN]] : f16
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
+func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<-3> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
+// CHECK: return %[[INV]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_neg_even_power
+func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<-4> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
+// CHECK: return %[[INV]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
+func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<5> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
+// CHECK: return %[[PW5]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @math_fpowi_pos_even_power
+func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+ %1 = arith.constant dense<4> : tensor<8xi64>
+ %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+ return %2 : tensor<8xf32>
+}
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+// CHECK: return %[[PW4]] : tensor<8xf32>
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 7ce8b5a7cfe9b3..97600ad1ebe7a3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
populateExpandPowFPattern(patterns);
+ populateExpandFPowIPattern(patterns);
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
02ab3da
to
4f9359c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more thing
5917a99
to
3b6559d
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
f10c398
to
9ecbb3d
Compare
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res); | ||
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity, | ||
res); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when both the base and the power are zero? I checked the llvm langref and it doesn't mention this case: https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic . The c standard library is more informative here: https://en.cppreference.com/w/cpp/numeric/math/pow#:~:text=in%20math_errhandling.-,If,is%20negative%2C%20a%20domain%20error%20or%20a%20pole%20error%20may%20occur.,-If%20the%20implementation .
Because neither math
nor llvm
define this, I don't think we have to worry about this, but it would be nice to have a comment that explains these corner cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added this information in the doc comments. I also experimented with PyTorch and Python.
>>> torch.pow(torch.tensor(0.0), 0)
tensor(1.)
>>> 0.0 ** 0
1.0
They both give 1 as an output. So we're good to go. go.
239cab4
to
60cb5ab
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for the fixes.
Thanks for the review. Your feedback was very valuable. |
This is tested in llvm via code added: llvm/llvm-project#87081 Fixes #16906
-- Convert
math.fpowi
to a series ofarith.mulf
operations.-- If the power is negative, we divide the result by 1.