-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math] Expand powfI operation for constant power operand. #87081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// | ||
//===- ExpandPatterns.cpp - Code to expand various math operations. -------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements expansion of tanh op. | ||
// This file implements expansion of various math operations. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
|
@@ -23,9 +23,14 @@ | |
using namespace mlir; | ||
|
||
/// Create a float constant. | ||
static Value createFloatConst(Location loc, Type type, double value, | ||
static Value createFloatConst(Location loc, Type type, APFloat value, | ||
OpBuilder &b) { | ||
auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); | ||
bool losesInfo = false; | ||
auto eltType = getElementTypeOrSelf(type); | ||
// Convert double to the given `FloatType` with round-to-nearest-ties-to-even. | ||
value.convert(cast<FloatType>(eltType).getFloatSemantics(), | ||
APFloat::rmNearestTiesToEven, &losesInfo); | ||
auto attr = b.getFloatAttr(eltType, value); | ||
if (auto shapedTy = dyn_cast<ShapedType>(type)) { | ||
return b.create<arith::ConstantOp>(loc, | ||
DenseElementsAttr::get(shapedTy, attr)); | ||
|
@@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value, | |
return b.create<arith::ConstantOp>(loc, attr); | ||
} | ||
|
||
/// Create a float constant. | ||
static Value createFloatConst(Location loc, Type type, double value, | ||
OpBuilder &b) { | ||
return createFloatConst(loc, type, APFloat(value), b); | ||
} | ||
|
||
/// Create an integer constant. | ||
static Value createIntConst(Location loc, Type type, int64_t value, | ||
OpBuilder &b) { | ||
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); | ||
|
@@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { | |
rewriter.replaceOp(op, ret); | ||
return success(); | ||
} | ||
|
||
// Convert `math.fpowi` to a series of `arith.mulf` operations. | ||
// If the power is negative, we divide one by the result. | ||
// If both the base and power are zero, the result is 1. | ||
static LogicalResult convertFPowICstOp(math::FPowIOp op, | ||
PatternRewriter &rewriter) { | ||
ImplicitLocOpBuilder b(op->getLoc(), rewriter); | ||
Value base = op.getOperand(0); | ||
Value power = op.getOperand(1); | ||
Type baseType = base.getType(); | ||
|
||
Attribute cstAttr; | ||
if (!matchPattern(power, m_Constant(&cstAttr))) | ||
return failure(); | ||
|
||
APInt value; | ||
if (!matchPattern(cstAttr, m_ConstantInt(&value))) | ||
return failure(); | ||
|
||
int64_t powerInt = value.getSExtValue(); | ||
bool isNegative = powerInt < 0; | ||
int64_t absPower = std::abs(powerInt); | ||
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); | ||
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter); | ||
|
||
while (absPower > 0) { | ||
if (absPower & 1) | ||
res = b.create<arith::MulFOp>(baseType, base, res); | ||
absPower >>= 1; | ||
base = b.create<arith::MulFOp>(baseType, base, base); | ||
} | ||
|
||
// Make sure not to introduce UB in case of negative power. | ||
if (isNegative) { | ||
auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType)) | ||
.getFloatSemantics(); | ||
Value zero = | ||
createFloatConst(op->getLoc(), baseType, | ||
APFloat::getZero(sem, /*Negative=*/false), rewriter); | ||
Value negZero = | ||
createFloatConst(op->getLoc(), baseType, | ||
APFloat::getZero(sem, /*Negative=*/true), rewriter); | ||
Value posInfinity = | ||
createFloatConst(op->getLoc(), baseType, | ||
APFloat::getInf(sem, /*Negative=*/false), rewriter); | ||
Value negInfinity = | ||
createFloatConst(op->getLoc(), baseType, | ||
APFloat::getInf(sem, /*Negative=*/true), rewriter); | ||
Value zeroEqCheck = | ||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero); | ||
Value negZeroEqCheck = | ||
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero); | ||
res = b.create<arith::DivFOp>(baseType, one, res); | ||
res = | ||
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res); | ||
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity, | ||
res); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens when both the base and the power are zero? I checked the llvm langref and it doesn't mention this case: https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic . The c standard library is more informative here: https://en.cppreference.com/w/cpp/numeric/math/pow#:~:text=in%20math_errhandling.-,If,is%20negative%2C%20a%20domain%20error%20or%20a%20pole%20error%20may%20occur.,-If%20the%20implementation . Because neither There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added this information in the doc comments. I also experimented with PyTorch and Python.
They both give 1 as an output. So we're good to go. go. |
||
|
||
rewriter.replaceOp(op, res); | ||
return success(); | ||
} | ||
|
||
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) | ||
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { | ||
ImplicitLocOpBuilder b(op->getLoc(), rewriter); | ||
|
@@ -517,6 +590,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { | |
patterns.add(convertPowfOp); | ||
} | ||
|
||
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { | ||
patterns.add(convertFPowICstOp); | ||
} | ||
|
||
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { | ||
patterns.add(convertRoundOp); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.