Skip to content

[mlir][math] Expand powfI operation for constant power operand. #87081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
Expand Down
87 changes: 82 additions & 5 deletions mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of tanh op.
// This file implements expansion of various math operations.
//
//===----------------------------------------------------------------------===//

Expand All @@ -23,9 +23,14 @@
using namespace mlir;

/// Create a float constant.
static Value createFloatConst(Location loc, Type type, double value,
static Value createFloatConst(Location loc, Type type, APFloat value,
OpBuilder &b) {
auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
bool losesInfo = false;
auto eltType = getElementTypeOrSelf(type);
// Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
value.convert(cast<FloatType>(eltType).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
auto attr = b.getFloatAttr(eltType, value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
return b.create<arith::ConstantOp>(loc,
DenseElementsAttr::get(shapedTy, attr));
Expand All @@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value,
return b.create<arith::ConstantOp>(loc, attr);
}

/// Create a float constant.
static Value createFloatConst(Location loc, Type type, double value,
OpBuilder &b) {
return createFloatConst(loc, type, APFloat(value), b);
}

/// Create an integer constant.
static Value createIntConst(Location loc, Type type, int64_t value,
OpBuilder &b) {
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
Expand Down Expand Up @@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
rewriter.replaceOp(op, ret);
return success();
}

// Convert `math.fpowi` to a series of `arith.mulf` operations.
// If the power is negative, we divide one by the result.
// If both the base and power are zero, the result is 1.
static LogicalResult convertFPowICstOp(math::FPowIOp op,
PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value base = op.getOperand(0);
Value power = op.getOperand(1);
Type baseType = base.getType();

Attribute cstAttr;
if (!matchPattern(power, m_Constant(&cstAttr)))
return failure();

APInt value;
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
return failure();

int64_t powerInt = value.getSExtValue();
bool isNegative = powerInt < 0;
int64_t absPower = std::abs(powerInt);
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);

while (absPower > 0) {
if (absPower & 1)
res = b.create<arith::MulFOp>(baseType, base, res);
absPower >>= 1;
base = b.create<arith::MulFOp>(baseType, base, base);
}

// Make sure not to introduce UB in case of negative power.
if (isNegative) {
auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
.getFloatSemantics();
Value zero =
createFloatConst(op->getLoc(), baseType,
APFloat::getZero(sem, /*Negative=*/false), rewriter);
Value negZero =
createFloatConst(op->getLoc(), baseType,
APFloat::getZero(sem, /*Negative=*/true), rewriter);
Value posInfinity =
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/false), rewriter);
Value negInfinity =
createFloatConst(op->getLoc(), baseType,
APFloat::getInf(sem, /*Negative=*/true), rewriter);
Value zeroEqCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
res = b.create<arith::DivFOp>(baseType, one, res);
res =
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
res);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when both the base and the power are zero? I checked the llvm langref and it doesn't mention this case: https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic . The c standard library is more informative here: https://en.cppreference.com/w/cpp/numeric/math/pow#:~:text=in%20math_errhandling.-,If,is%20negative%2C%20a%20domain%20error%20or%20a%20pole%20error%20may%20occur.,-If%20the%20implementation .

Because neither math nor llvm define this, I don't think we have to worry about this, but it would be nice to have a comment that explains these corner cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added this information in the doc comments. I also experimented with PyTorch and Python.

>>> torch.pow(torch.tensor(0.0), 0)
tensor(1.)
>>> 0.0 ** 0
1.0

They both give 1 as an output. So we're good to go. go.


rewriter.replaceOp(op, res);
return success();
}

// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Expand Down Expand Up @@ -517,6 +590,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
patterns.add(convertPowfOp);
}

void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
patterns.add(convertFPowICstOp);
}

void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundOp);
}
Expand Down
99 changes: 99 additions & 0 deletions mlir/test/Dialect/Math/expand-math.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,102 @@ func.func @roundeven16(%arg: f16) -> f16 {
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16

// CHECK: return %[[COPYSIGN]] : f16

// -----

// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<-3> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CSTNEG0]] : tensor<8xf32>
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[UB2]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_neg_even_power
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<-4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CSTNEG0]] : tensor<8xf32>
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[UB2]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<5> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
// CHECK: return %[[PW5]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_pos_even_power
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
%1 = arith.constant dense<4> : tensor<8xi64>
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: return %[[PW4]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_even_scalar
func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
%pow = arith.constant 2 : i64
%2 = math.fpowi %0, %pow : f32, i64
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: return %[[SQ]] : f32

// -----

// CHECK-LABEL: func.func @math_fpowi_scalar_zero
func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
%pow = arith.constant 0 : i64
%2 = math.fpowi %0, %pow : f32, i64
return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
// CHECK: %[[RET:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: return %[[RET]] : f32

// -----
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
populateExpandPowFPattern(patterns);
populateExpandFPowIPattern(patterns);
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
Expand Down