Skip to content

Commit 4f9359c

Browse files
committed
[mlir][math] Expand powfI operation for constant power operand.
1 parent 89bae85 commit 4f9359c

File tree

4 files changed

+101
-0
lines changed

4 files changed

+101
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
3636
void populateExpandCeilFPattern(RewritePatternSet &patterns);
3737
void populateExpandExp2FPattern(RewritePatternSet &patterns);
3838
void populateExpandPowFPattern(RewritePatternSet &patterns);
39+
void populateExpandFPowIPattern(RewritePatternSet &patterns);
3940
void populateExpandRoundFPattern(RewritePatternSet &patterns);
4041
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
4142
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,44 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
202202
rewriter.replaceOp(op, ret);
203203
return success();
204204
}
205+
206+
// Convert `math.fpowi` to a series of `arith.mulf` operations.
207+
// If the power is negative, we divide one by the result.
208+
static LogicalResult convertFPowIOp(math::FPowIOp op,
209+
PatternRewriter &rewriter) {
210+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211+
Value base = op.getOperand(0);
212+
Value power = op.getOperand(1);
213+
Type baseType = base.getType();
214+
215+
Attribute cstAttr;
216+
if (!matchPattern(power, m_Constant(&cstAttr)))
217+
return failure();
218+
219+
auto iAttr = dyn_cast<SplatElementsAttr>(cstAttr);
220+
if (!iAttr)
221+
return failure();
222+
223+
int64_t powerInt = iAttr.getSplatValue<int64_t>();
224+
bool isNegative = powerInt < 0;
225+
int64_t absPower = std::abs(powerInt);
226+
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
227+
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
228+
229+
while (absPower > 0) {
230+
if (absPower & 1)
231+
res = b.create<arith::MulFOp>(baseType, base, res);
232+
absPower >>= 1;
233+
base = b.create<arith::MulFOp>(baseType, base, base);
234+
}
235+
236+
if (isNegative)
237+
res = b.create<arith::DivFOp>(baseType, one, res);
238+
239+
rewriter.replaceOp(op, res);
240+
return success();
241+
}
242+
205243
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
206244
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
207245
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +555,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
517555
patterns.add(convertPowfOp);
518556
}
519557

558+
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
559+
patterns.add(convertFPowIOp);
560+
}
561+
520562
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
521563
patterns.add(convertRoundOp);
522564
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,60 @@ func.func @roundeven16(%arg: f16) -> f16 {
511511
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
512512

513513
// CHECK: return %[[COPYSIGN]] : f16
514+
515+
// -----
516+
517+
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
518+
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
519+
%1 = arith.constant dense<-3> : tensor<8xi64>
520+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
521+
return %2 : tensor<8xf32>
522+
}
523+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
524+
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
525+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
526+
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
527+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
528+
// CHECK: return %[[INV]] : tensor<8xf32>
529+
530+
// -----
531+
532+
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
533+
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
534+
%1 = arith.constant dense<-4> : tensor<8xi64>
535+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
536+
return %2 : tensor<8xf32>
537+
}
538+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
539+
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
540+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
541+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
542+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
543+
// CHECK: return %[[INV]] : tensor<8xf32>
544+
545+
// -----
546+
547+
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
548+
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
549+
%1 = arith.constant dense<5> : tensor<8xi64>
550+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
551+
return %2 : tensor<8xf32>
552+
}
553+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
554+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
555+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
556+
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
557+
// CHECK: return %[[PW5]] : tensor<8xf32>
558+
559+
// -----
560+
561+
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
562+
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
563+
%1 = arith.constant dense<4> : tensor<8xi64>
564+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
565+
return %2 : tensor<8xf32>
566+
}
567+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
568+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
569+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
570+
// CHECK: return %[[PW4]] : tensor<8xf32>

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
4646
populateExpandFloorFPattern(patterns);
4747
populateExpandCeilFPattern(patterns);
4848
populateExpandPowFPattern(patterns);
49+
populateExpandFPowIPattern(patterns);
4950
populateExpandRoundFPattern(patterns);
5051
populateExpandRoundEvenPattern(patterns);
5152
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)