Skip to content

Commit 2a04ce2

Browse files
committed
[mlir][math] Expand powfI operation for constant power operand.
1 parent 89bae85 commit 2a04ce2

File tree

4 files changed

+105
-0
lines changed

4 files changed

+105
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
3636
void populateExpandCeilFPattern(RewritePatternSet &patterns);
3737
void populateExpandExp2FPattern(RewritePatternSet &patterns);
3838
void populateExpandPowFPattern(RewritePatternSet &patterns);
39+
void populateExpandFPowIPattern(RewritePatternSet &patterns);
3940
void populateExpandRoundFPattern(RewritePatternSet &patterns);
4041
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
4142
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,48 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
202202
rewriter.replaceOp(op, ret);
203203
return success();
204204
}
205+
206+
// Convert `math.fpowi` to a series of `arith.mulf` operations.
207+
// If the power is negative, we divide the result by 1.
208+
static LogicalResult convertFPowIOp(math::FPowIOp op,
209+
PatternRewriter &rewriter) {
210+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211+
Value operandA = op.getOperand(0);
212+
Value operandB = op.getOperand(1);
213+
Type opType = operandA.getType();
214+
auto conOp =
215+
mlir::dyn_cast<mlir::arith::ConstantOp>(operandB.getDefiningOp());
216+
217+
if (!conOp)
218+
return failure();
219+
220+
auto iAttr = dyn_cast<mlir::SplatElementsAttr>(conOp.getValue());
221+
222+
if (!iAttr)
223+
return failure();
224+
225+
int64_t power = iAttr.getSplatValue<int64_t>();
226+
bool neg = power < 0;
227+
int64_t absPower = std::abs(power);
228+
Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
229+
Value res = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
230+
231+
while (absPower > 0) {
232+
233+
if (absPower & 1)
234+
res = b.create<arith::MulFOp>(opType, operandA, res);
235+
236+
absPower = absPower >> 1;
237+
operandA = b.create<arith::MulFOp>(opType, operandA, operandA);
238+
}
239+
240+
if (neg)
241+
res = b.create<arith::DivFOp>(opType, one, res);
242+
243+
rewriter.replaceOp(op, res);
244+
return success();
245+
}
246+
205247
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
206248
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
207249
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +559,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
517559
patterns.add(convertPowfOp);
518560
}
519561

562+
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
563+
patterns.add(convertFPowIOp);
564+
}
565+
520566
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
521567
patterns.add(convertRoundOp);
522568
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,60 @@ func.func @roundeven16(%arg: f16) -> f16 {
511511
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
512512

513513
// CHECK: return %[[COPYSIGN]] : f16
514+
515+
// -----
516+
517+
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
518+
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
519+
%1 = arith.constant dense<-3> : tensor<8xi64>
520+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
521+
return %2 : tensor<8xf32>
522+
}
523+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
524+
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
525+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
526+
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
527+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
528+
// CHECK: return %[[INV]] : tensor<8xf32>
529+
530+
// -----
531+
532+
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
533+
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
534+
%1 = arith.constant dense<-4> : tensor<8xi64>
535+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
536+
return %2 : tensor<8xf32>
537+
}
538+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
539+
// CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
540+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
541+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
542+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
543+
// CHECK: return %[[INV]] : tensor<8xf32>
544+
545+
// -----
546+
547+
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
548+
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
549+
%1 = arith.constant dense<5> : tensor<8xi64>
550+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
551+
return %2 : tensor<8xf32>
552+
}
553+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
554+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
555+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
556+
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
557+
// CHECK: return %[[PW5]] : tensor<8xf32>
558+
559+
// -----
560+
561+
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
562+
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
563+
%1 = arith.constant dense<4> : tensor<8xi64>
564+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
565+
return %2 : tensor<8xf32>
566+
}
567+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
568+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
569+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
570+
// CHECK: return %[[PW4]] : tensor<8xf32>

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
4646
populateExpandFloorFPattern(patterns);
4747
populateExpandCeilFPattern(patterns);
4848
populateExpandPowFPattern(patterns);
49+
populateExpandFPowIPattern(patterns);
4950
populateExpandRoundFPattern(patterns);
5051
populateExpandRoundEvenPattern(patterns);
5152
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)