Skip to content

Commit 10a57f3

Browse files
authored
[mlir][math] Expand powfI operation for constant power operand. (#87081)
-- Convert `math.fpowi` to a series of `arith.mulf` operations. -- If the power is negative, we divide one by the result.
1 parent cbb27be commit 10a57f3

File tree

4 files changed

+183
-5
lines changed

4 files changed

+183
-5
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
3636
void populateExpandCeilFPattern(RewritePatternSet &patterns);
3737
void populateExpandExp2FPattern(RewritePatternSet &patterns);
3838
void populateExpandPowFPattern(RewritePatternSet &patterns);
39+
void populateExpandFPowIPattern(RewritePatternSet &patterns);
3940
void populateExpandRoundFPattern(RewritePatternSet &patterns);
4041
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
4142
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
1+
//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements expansion of tanh op.
9+
// This file implements expansion of various math operations.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

@@ -23,9 +23,14 @@
2323
using namespace mlir;
2424

2525
/// Create a float constant.
26-
static Value createFloatConst(Location loc, Type type, double value,
26+
static Value createFloatConst(Location loc, Type type, APFloat value,
2727
OpBuilder &b) {
28-
auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
28+
bool losesInfo = false;
29+
auto eltType = getElementTypeOrSelf(type);
30+
// Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
31+
value.convert(cast<FloatType>(eltType).getFloatSemantics(),
32+
APFloat::rmNearestTiesToEven, &losesInfo);
33+
auto attr = b.getFloatAttr(eltType, value);
2934
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
3035
return b.create<arith::ConstantOp>(loc,
3136
DenseElementsAttr::get(shapedTy, attr));
@@ -34,7 +39,12 @@ static Value createFloatConst(Location loc, Type type, double value,
3439
return b.create<arith::ConstantOp>(loc, attr);
3540
}
3641

37-
/// Create a float constant.
42+
static Value createFloatConst(Location loc, Type type, double value,
43+
OpBuilder &b) {
44+
return createFloatConst(loc, type, APFloat(value), b);
45+
}
46+
47+
/// Create an integer constant.
3848
static Value createIntConst(Location loc, Type type, int64_t value,
3949
OpBuilder &b) {
4050
auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
@@ -202,6 +212,69 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
202212
rewriter.replaceOp(op, ret);
203213
return success();
204214
}
215+
216+
// Convert `math.fpowi` to a series of `arith.mulf` operations.
217+
// If the power is negative, we divide one by the result.
218+
// If both the base and power are zero, the result is 1.
219+
static LogicalResult convertFPowICstOp(math::FPowIOp op,
220+
PatternRewriter &rewriter) {
221+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
222+
Value base = op.getOperand(0);
223+
Value power = op.getOperand(1);
224+
Type baseType = base.getType();
225+
226+
Attribute cstAttr;
227+
if (!matchPattern(power, m_Constant(&cstAttr)))
228+
return failure();
229+
230+
APInt value;
231+
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
232+
return failure();
233+
234+
int64_t powerInt = value.getSExtValue();
235+
bool isNegative = powerInt < 0;
236+
int64_t absPower = std::abs(powerInt);
237+
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
238+
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
239+
240+
while (absPower > 0) {
241+
if (absPower & 1)
242+
res = b.create<arith::MulFOp>(baseType, base, res);
243+
absPower >>= 1;
244+
base = b.create<arith::MulFOp>(baseType, base, base);
245+
}
246+
247+
// Make sure not to introduce UB in case of negative power.
248+
if (isNegative) {
249+
auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
250+
.getFloatSemantics();
251+
Value zero =
252+
createFloatConst(op->getLoc(), baseType,
253+
APFloat::getZero(sem, /*Negative=*/false), rewriter);
254+
Value negZero =
255+
createFloatConst(op->getLoc(), baseType,
256+
APFloat::getZero(sem, /*Negative=*/true), rewriter);
257+
Value posInfinity =
258+
createFloatConst(op->getLoc(), baseType,
259+
APFloat::getInf(sem, /*Negative=*/false), rewriter);
260+
Value negInfinity =
261+
createFloatConst(op->getLoc(), baseType,
262+
APFloat::getInf(sem, /*Negative=*/true), rewriter);
263+
Value zeroEqCheck =
264+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
265+
Value negZeroEqCheck =
266+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
267+
res = b.create<arith::DivFOp>(baseType, one, res);
268+
res =
269+
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
270+
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
271+
res);
272+
}
273+
274+
rewriter.replaceOp(op, res);
275+
return success();
276+
}
277+
205278
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
206279
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
207280
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +590,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
517590
patterns.add(convertPowfOp);
518591
}
519592

593+
/// Registers the expansion of `math.fpowi` with a constant power operand
/// into repeated `arith.mulf` operations.
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
  patterns.add(convertFPowICstOp);
}
596+
520597
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
521598
patterns.add(convertRoundOp);
522599
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,102 @@ func.func @roundeven16(%arg: f16) -> f16 {
511511
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
512512

513513
// CHECK: return %[[COPYSIGN]] : f16
514+
515+
// -----

// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
  %1 = arith.constant dense<-3> : tensor<8xi64>
  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
  return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[CUBE]], %[[CSTNEG0]] : tensor<8xf32>
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[UB2]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_neg_even_power
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
  %1 = arith.constant dense<-4> : tensor<8xi64>
  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
  return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CST0]] : tensor<8xf32>
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[PW4]], %[[CSTNEG0]] : tensor<8xf32>
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
// CHECK: return %[[UB2]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
  %1 = arith.constant dense<5> : tensor<8xi64>
  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
  return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
// CHECK: return %[[PW5]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_pos_even_power
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
  %1 = arith.constant dense<4> : tensor<8xi64>
  %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
  return %2 : tensor<8xf32>
}
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
// CHECK: return %[[PW4]] : tensor<8xf32>

// -----

// CHECK-LABEL: func.func @math_fpowi_even_scalar
func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
  %pow = arith.constant 2 : i64
  %2 = math.fpowi %0, %pow : f32, i64
  return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
// CHECK: return %[[SQ]] : f32

// -----

// CHECK-LABEL: func.func @math_fpowi_scalar_zero
func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
  %pow = arith.constant 0 : i64
  %2 = math.fpowi %0, %pow : f32, i64
  return %2 : f32
}
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
// CHECK: %[[RET:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: return %[[RET]] : f32

// -----

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
4646
populateExpandFloorFPattern(patterns);
4747
populateExpandCeilFPattern(patterns);
4848
populateExpandPowFPattern(patterns);
49+
populateExpandFPowIPattern(patterns);
4950
populateExpandRoundFPattern(patterns);
5051
populateExpandRoundEvenPattern(patterns);
5152
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)