Skip to content

Commit 9ecbb3d

Browse files
committed
[mlir][math] Expand powfI operation for constant power operand.
1 parent 89bae85 commit 9ecbb3d

File tree

4 files changed

+179
-2
lines changed

4 files changed

+179
-2
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
3636
void populateExpandCeilFPattern(RewritePatternSet &patterns);
3737
void populateExpandExp2FPattern(RewritePatternSet &patterns);
3838
void populateExpandPowFPattern(RewritePatternSet &patterns);
39+
void populateExpandFPowIPattern(RewritePatternSet &patterns);
3940
void populateExpandRoundFPattern(RewritePatternSet &patterns);
4041
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
4142
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
2-
//
31
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42
// See https://llvm.org/LICENSE.txt for license information.
53
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -33,6 +31,16 @@ static Value createFloatConst(Location loc, Type type, double value,
3331

3432
return b.create<arith::ConstantOp>(loc, attr);
3533
}
34+
static Value createFloatConst(Location loc, Type type, APFloat value,
35+
OpBuilder &b) {
36+
auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
37+
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
38+
return b.create<arith::ConstantOp>(loc,
39+
DenseElementsAttr::get(shapedTy, attr));
40+
}
41+
42+
return b.create<arith::ConstantOp>(loc, attr);
43+
}
3644

3745
/// Create a float constant.
3846
static Value createIntConst(Location loc, Type type, int64_t value,
@@ -202,6 +210,70 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
202210
rewriter.replaceOp(op, ret);
203211
return success();
204212
}
213+
214+
// Convert `math.fpowi` to a series of `arith.mulf` operations.
215+
// If the power is negative, we divide one by the result.
216+
static LogicalResult convertFPowICstOp(math::FPowIOp op,
217+
PatternRewriter &rewriter) {
218+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
219+
Value base = op.getOperand(0);
220+
Value power = op.getOperand(1);
221+
Type baseType = base.getType();
222+
Value tempBase = op.getOperand(0);
223+
224+
Attribute cstAttr;
225+
if (!matchPattern(power, m_Constant(&cstAttr)))
226+
return failure();
227+
228+
APInt value;
229+
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
230+
return failure();
231+
232+
int64_t powerInt = value.getSExtValue();
233+
bool isNegative = powerInt < 0;
234+
int64_t absPower = std::abs(powerInt);
235+
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
236+
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
237+
238+
auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
239+
.getFloatSemantics();
240+
Value zero =
241+
createFloatConst(op->getLoc(), baseType,
242+
APFloat::getZero(sem, /*Negative=*/false), rewriter);
243+
Value negZero =
244+
createFloatConst(op->getLoc(), baseType,
245+
APFloat::getZero(sem, /*Negative=*/true), rewriter);
246+
Value posInfinity =
247+
createFloatConst(op->getLoc(), baseType,
248+
APFloat::getInf(sem, /*Negative=*/false), rewriter);
249+
Value negInfinity =
250+
createFloatConst(op->getLoc(), baseType,
251+
APFloat::getInf(sem, /*Negative=*/true), rewriter);
252+
253+
while (absPower > 0) {
254+
if (absPower & 1)
255+
res = b.create<arith::MulFOp>(baseType, tempBase, res);
256+
absPower >>= 1;
257+
tempBase = b.create<arith::MulFOp>(baseType, tempBase, tempBase);
258+
}
259+
260+
// Take care of UB in case of negative power.
261+
if (isNegative) {
262+
res = b.create<arith::DivFOp>(baseType, one, res);
263+
Value zeroEqCheck =
264+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, base, zero);
265+
Value negZeroEqCheck =
266+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, base, negZero);
267+
res =
268+
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
269+
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
270+
res);
271+
}
272+
273+
rewriter.replaceOp(op, res);
274+
return success();
275+
}
276+
205277
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
206278
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
207279
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +589,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
517589
patterns.add(convertPowfOp);
518590
}
519591

592+
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
593+
patterns.add(convertFPowICstOp);
594+
}
595+
520596
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
521597
patterns.add(convertRoundOp);
522598
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,102 @@ func.func @roundeven16(%arg: f16) -> f16 {
511511
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
512512

513513
// CHECK: return %[[COPYSIGN]] : f16
514+
515+
// -----
516+
517+
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
518+
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
519+
%1 = arith.constant dense<-3> : tensor<8xi64>
520+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
521+
return %2 : tensor<8xf32>
522+
}
523+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
524+
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
525+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
526+
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
527+
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
528+
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
529+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
530+
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
531+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
532+
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CST0]] : tensor<8xf32>
533+
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CSTNEG0]] : tensor<8xf32>
534+
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
535+
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
536+
// CHECK: return %[[UB2]] : tensor<8xf32>
537+
538+
// -----
539+
540+
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
541+
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
542+
%1 = arith.constant dense<-4> : tensor<8xi64>
543+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
544+
return %2 : tensor<8xf32>
545+
}
546+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
547+
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
548+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
549+
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
550+
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
551+
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
552+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
553+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
554+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
555+
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CST0]] : tensor<8xf32>
556+
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CSTNEG0]] : tensor<8xf32>
557+
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
558+
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
559+
// CHECK: return %[[UB2]] : tensor<8xf32>
560+
561+
// -----
562+
563+
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
564+
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
565+
%1 = arith.constant dense<5> : tensor<8xi64>
566+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
567+
return %2 : tensor<8xf32>
568+
}
569+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
570+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
571+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
572+
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
573+
// CHECK: return %[[PW5]] : tensor<8xf32>
574+
575+
// -----
576+
577+
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
578+
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
579+
%1 = arith.constant dense<4> : tensor<8xi64>
580+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
581+
return %2 : tensor<8xf32>
582+
}
583+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
584+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
585+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
586+
// CHECK: return %[[PW4]] : tensor<8xf32>
587+
588+
// -----
589+
590+
// CHECK-LABEL: func.func @math_fpowi_even_scalar
591+
func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
592+
%pow = arith.constant 2 : i64
593+
%2 = math.fpowi %0, %pow : f32, i64
594+
return %2 : f32
595+
}
596+
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
597+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
598+
// CHECK: return %[[SQ]] : f32
599+
600+
// -----
601+
602+
// CHECK-LABEL: func.func @math_fpowi_scalar_zero
603+
func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
604+
%pow = arith.constant 0 : i64
605+
%2 = math.fpowi %0, %pow : f32, i64
606+
return %2 : f32
607+
}
608+
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
609+
// CHECK: %[[RET:.*]] = arith.constant 1.000000e+00 : f32
610+
// CHECK: return %[[RET]] : f32
611+
612+
// -----

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
4646
populateExpandFloorFPattern(patterns);
4747
populateExpandCeilFPattern(patterns);
4848
populateExpandPowFPattern(patterns);
49+
populateExpandFPowIPattern(patterns);
4950
populateExpandRoundFPattern(patterns);
5051
populateExpandRoundEvenPattern(patterns);
5152
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)