
[mlir][math] Expand fpowi operation for constant power operand. #87081


Merged
merged 1 commit on Apr 1, 2024

Conversation

pashu123
Member

-- Convert math.fpowi to a series of arith.mulf operations.
-- If the power is negative, we divide 1 by the result (i.e., take the reciprocal).
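
For readers unfamiliar with the trick, here is a minimal standalone C++ sketch of the exponentiation-by-squaring loop this rewrite emits (plain scalar code rather than MLIR; the fpowi helper is illustrative, not an existing API):

    #include <cstdint>
    #include <cstdlib>

    // base^power in O(log |power|) multiplies, mirroring the chain of
    // arith.mulf operations the pattern builds.
    double fpowi(double base, int64_t power) {
      bool negative = power < 0;
      int64_t absPower = std::abs(power);
      double result = 1.0;
      while (absPower > 0) {
        if (absPower & 1) // Odd exponent: fold one factor into the result.
          result *= base;
        absPower >>= 1;   // Halve the exponent.
        base *= base;     // Square the running base.
      }
      // A negative power becomes 1 / base^|power|.
      return negative ? 1.0 / result : result;
    }

Once the greedy driver folds away the initial multiply by 1.0, this corresponds to the arith.mulf chains checked in the new expand-math.mlir tests.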

@llvmbot
Member

llvmbot commented Mar 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Prashant Kumar (pashu123)

Changes

-- Convert math.fpowi to a series of arith.mulf operations.
-- If the power is negative, we divide 1 by the result (i.e., take the reciprocal).


Full diff: https://github.com/llvm/llvm-project/pull/87081.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+1)
  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+46)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+57)
  • (modified) mlir/test/lib/Dialect/Math/TestExpandMath.cpp (+1)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 11b2c7a7afa2f7..e2c513047c77a5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
 void populateExpandPowFPattern(RewritePatternSet &patterns);
+void populateExpandFPowIPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index e1ab9c905447b7..82947d064c9fff 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -202,6 +202,48 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   rewriter.replaceOp(op, ret);
   return success();
 }
+
+// Convert `math.fpowi` to a series of `arith.mulf` operations.
+// If the power is negative, we divide 1 by the result.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Type opType = operandA.getType();
+  auto conOp =
+      operandB.getDefiningOp<mlir::arith::ConstantOp>();
+
+  if (!conOp)
+    return failure();
+
+  auto iAttr = dyn_cast<mlir::SplatElementsAttr>(conOp.getValue());
+
+  if (!iAttr)
+    return failure();
+
+  int64_t power = iAttr.getSplatValue<int64_t>();
+  bool neg = power < 0;
+  int64_t absPower = std::abs(power);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+  Value res = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+
+  while (absPower > 0) {
+
+    if (absPower & 1) // Odd exponent: multiply the result by the base.
+      res = b.create<arith::MulFOp>(opType, operandA, res);
+
+    absPower = absPower >> 1; // Halve the exponent.
+    operandA = b.create<arith::MulFOp>(opType, operandA, operandA); // Square.
+  }
+
+  if (neg)
+    res = b.create<arith::DivFOp>(opType, one, res);
+
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
 // Converts  Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +559,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
   patterns.add(convertPowfOp);
 }
 
+void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
+  patterns.add(convertFPowIOp);
+}
+
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 6326d3a71874b4..01dfe4783cfb69 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -511,3 +511,60 @@ func.func @roundeven16(%arg: f16) -> f16 {
 // CHECK:     %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
 
 // CHECK: return %[[COPYSIGN]] : f16
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_neg_odd_power
+func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<-3> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
+//  CHECK:      return %[[INV]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_neg_even_power
+func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<-4> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:        %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
+//  CHECK:      return %[[INV]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_pos_odd_power
+func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<5> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:        %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:      return %[[PW5]] : tensor<8xf32>
+
+// -----
+
+// CHECK-LABEL:   func.func @math_fpowi_pos_even_power
+func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
+  %1 = arith.constant dense<4> : tensor<8xi64> 
+  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
+  return %2 : tensor<8xf32>
+}
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
+//  CHECK:        %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
+//  CHECK:        %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
+//  CHECK:      return %[[PW4]] : tensor<8xf32>
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 7ce8b5a7cfe9b3..97600ad1ebe7a3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
   populateExpandPowFPattern(patterns);
+  populateExpandFPowIPattern(patterns);
   populateExpandRoundFPattern(patterns);
   populateExpandRoundEvenPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
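
Outside the test pass, a client pass would enable the new expansion the same way TestExpandMath.cpp does; a minimal sketch of the pass body (the surrounding pass class and includes are assumed):

    // Inside some pass's runOnOperation(), mirroring TestExpandMath.cpp above.
    RewritePatternSet patterns(&getContext());
    populateExpandFPowIPattern(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));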

@pashu123 pashu123 force-pushed the math_fpowi branch 2 times, most recently from 02ab3da to 4f9359c on March 29, 2024 at 16:48
Member

@kuhar kuhar left a comment


One more thing

@pashu123 pashu123 force-pushed the math_fpowi branch 2 times, most recently from 5917a99 to 3b6559d on March 30, 2024 at 13:23
@pashu123 pashu123 requested a review from kuhar March 30, 2024 13:25
Copy link

github-actions bot commented Mar 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@pashu123 pashu123 force-pushed the math_fpowi branch 3 times, most recently from f10c398 to 9ecbb3d on March 31, 2024 at 07:12
@pashu123 pashu123 requested a review from kuhar March 31, 2024 09:21
(This thread is anchored on a snippet from a later revision of convertFPowIOp, which appears to special-case a zero base raised to a negative power by selecting positive or negative infinity:)

    b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
    res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
                                    res);
  }
Member


What happens when both the base and the power are zero? I checked the LLVM LangRef and it doesn't mention this case: https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic. The C standard library is more informative here: https://en.cppreference.com/w/cpp/numeric/math/pow

Because neither math nor LLVM defines this, I don't think we have to worry about it, but it would be nice to have a comment that explains these corner cases.

Member Author


I have added this information in the doc comments. I also experimented with PyTorch and Python.

>>> torch.pow(torch.tensor(0.0), 0)
tensor(1.)
>>> 0.0 ** 0
1.0

They both give 1 as the output, so we're good to go. (The expansion here matches that behavior: with a zero power, the multiply loop never executes and the result stays at its initial constant of 1.0.)

@pashu123 pashu123 force-pushed the math_fpowi branch 2 times, most recently from 239cab4 to 60cb5ab on April 1, 2024 at 05:34
@pashu123 pashu123 requested a review from kuhar April 1, 2024 05:36
Member

@kuhar kuhar left a comment


LGTM. Thanks for the fixes.

@pashu123
Member Author

pashu123 commented Apr 1, 2024

> LGTM. Thanks for the fixes.

Thanks for the review. Your feedback was very valuable.
