Skip to content

Commit 58c30b4

Browse files
committed
[mlir][math] Expand powfI operation for constant power operand.
1 parent 89bae85 commit 58c30b4

File tree

4 files changed

+155
-0
lines changed

4 files changed

+155
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void populateExpandFloorFPattern(RewritePatternSet &patterns);
3636
void populateExpandCeilFPattern(RewritePatternSet &patterns);
3737
void populateExpandExp2FPattern(RewritePatternSet &patterns);
3838
void populateExpandPowFPattern(RewritePatternSet &patterns);
39+
void populateExpandFPowIPattern(RewritePatternSet &patterns);
3940
void populateExpandRoundFPattern(RewritePatternSet &patterns);
4041
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
4142
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,70 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
202202
rewriter.replaceOp(op, ret);
203203
return success();
204204
}
205+
206+
// Convert `math.fpowi` to a series of `arith.mulf` operations.
207+
// If the power is negative, we divide one by the result.
208+
static LogicalResult convertFPowICstOp(math::FPowIOp op,
209+
PatternRewriter &rewriter) {
210+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211+
Value base = op.getOperand(0);
212+
Value power = op.getOperand(1);
213+
Type baseType = base.getType();
214+
Value tempBase = op.getOperand(0);
215+
216+
Attribute cstAttr;
217+
if (!matchPattern(power, m_Constant(&cstAttr)))
218+
return failure();
219+
220+
int64_t powerInt;
221+
222+
// Check for Splat or Integer Attrs.
223+
if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr)) {
224+
powerInt = splatAttr.getSplatValue<int64_t>();
225+
} else if (auto iAttr = dyn_cast<IntegerAttr>(cstAttr)) {
226+
powerInt = iAttr.getInt();
227+
} else {
228+
return failure();
229+
}
230+
231+
bool isNegative = powerInt < 0;
232+
int64_t absPower = std::abs(powerInt);
233+
Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
234+
Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
235+
236+
Value zero = createFloatConst(op->getLoc(), baseType, 0.00, rewriter);
237+
Value negZero = createFloatConst(op->getLoc(), baseType, -0.00, rewriter);
238+
Value posInfinity =
239+
createFloatConst(op->getLoc(), baseType,
240+
std::numeric_limits<double_t>::infinity(), rewriter);
241+
Value negInfinity =
242+
createFloatConst(op->getLoc(), baseType,
243+
-std::numeric_limits<double_t>::infinity(), rewriter);
244+
245+
while (absPower > 0) {
246+
if (absPower & 1)
247+
res = b.create<arith::MulFOp>(baseType, tempBase, res);
248+
absPower >>= 1;
249+
tempBase = b.create<arith::MulFOp>(baseType, tempBase, tempBase);
250+
}
251+
252+
// Take care of UB in case of negative power.
253+
if (isNegative) {
254+
res = b.create<arith::DivFOp>(baseType, one, res);
255+
Value zeroEqCheck =
256+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, base, zero);
257+
Value negZeroEqCheck =
258+
b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, base, negZero);
259+
res =
260+
b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
261+
res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
262+
res);
263+
}
264+
265+
rewriter.replaceOp(op, res);
266+
return success();
267+
}
268+
205269
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
206270
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
207271
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
@@ -517,6 +581,10 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
517581
patterns.add(convertPowfOp);
518582
}
519583

584+
// Registers `convertFPowICstOp`, the rewrite for `math.fpowi` with a constant
// power operand, in `patterns`.
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
585+
patterns.add(convertFPowICstOp);
586+
}
587+
520588
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
521589
patterns.add(convertRoundOp);
522590
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,3 +511,88 @@ func.func @roundeven16(%arg: f16) -> f16 {
511511
// CHECK: %[[COPYSIGN:.*]] = math.copysign %[[RESULT]], %[[VAL_0]] : f16
512512

513513
// CHECK: return %[[COPYSIGN]] : f16
514+
515+
// -----
516+
517+
// CHECK-LABEL: func.func @math_fpowi_neg_odd_power
518+
func.func @math_fpowi_neg_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
519+
%1 = arith.constant dense<-3> : tensor<8xi64>
520+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
521+
return %2 : tensor<8xf32>
522+
}
523+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
524+
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
525+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
526+
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
527+
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
528+
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
529+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
530+
// CHECK: %[[CUBE:.*]] = arith.mulf %[[SQ]], %[[ARG0]] : tensor<8xf32>
531+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[CUBE]] : tensor<8xf32>
532+
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CST0]] : tensor<8xf32>
533+
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CSTNEG0]] : tensor<8xf32>
534+
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
535+
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
536+
// CHECK: return %[[UB2]] : tensor<8xf32>
537+
538+
// -----
539+
540+
// CHECK-LABEL: func.func @math_fpowi_neg_even_power
541+
func.func @math_fpowi_neg_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
542+
%1 = arith.constant dense<-4> : tensor<8xi64>
543+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
544+
return %2 : tensor<8xf32>
545+
}
546+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
547+
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32>
548+
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
549+
// CHECK-DAG: %[[CSTNEG0:.*]] = arith.constant dense<-0.000000e+00> : tensor<8xf32>
550+
// CHECK-DAG: %[[CSTINF:.*]] = arith.constant dense<0x7F800000> : tensor<8xf32>
551+
// CHECK-DAG: %[[CSTNEGINF:.*]] = arith.constant dense<0xFF800000> : tensor<8xf32>
552+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
553+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
554+
// CHECK: %[[INV:.*]] = arith.divf %[[CST1]], %[[PW4]] : tensor<8xf32>
555+
// CHECK: %[[CMP0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CST0]] : tensor<8xf32>
556+
// CHECK: %[[CMPNEG0:.*]] = arith.cmpf oeq, %[[ARG0]], %[[CSTNEG0]] : tensor<8xf32>
557+
// CHECK: %[[UB1:.*]] = arith.select %[[CMP0]], %[[CSTINF]], %[[INV]] : tensor<8xi1>, tensor<8xf32>
558+
// CHECK: %[[UB2:.*]] = arith.select %[[CMPNEG0]], %[[CSTNEGINF]], %[[UB1]] : tensor<8xi1>, tensor<8xf32>
559+
// CHECK: return %[[UB2]] : tensor<8xf32>
560+
561+
// -----
562+
563+
// CHECK-LABEL: func.func @math_fpowi_pos_odd_power
564+
func.func @math_fpowi_pos_odd_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
565+
%1 = arith.constant dense<5> : tensor<8xi64>
566+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
567+
return %2 : tensor<8xf32>
568+
}
569+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
570+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
571+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
572+
// CHECK: %[[PW5:.*]] = arith.mulf %[[PW4]], %[[ARG0]] : tensor<8xf32>
573+
// CHECK: return %[[PW5]] : tensor<8xf32>
574+
575+
// -----
576+
577+
// CHECK-LABEL: func.func @math_fpowi_pos_even_power
578+
func.func @math_fpowi_pos_even_power(%0 : tensor<8xf32>) -> tensor<8xf32> {
579+
%1 = arith.constant dense<4> : tensor<8xi64>
580+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi64>
581+
return %2 : tensor<8xf32>
582+
}
583+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>) -> tensor<8xf32> {
584+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
585+
// CHECK: %[[PW4:.*]] = arith.mulf %[[SQ]], %[[SQ]] : tensor<8xf32>
586+
// CHECK: return %[[PW4]] : tensor<8xf32>
587+
588+
// -----
589+
590+
// CHECK-LABEL: func.func @math_fpowi_even_scalar
591+
func.func @math_fpowi_even_scalar(%0 : f32) -> f32 {
592+
%pow = arith.constant 2 : i64
593+
%2 = math.fpowi %0, %pow : f32, i64
594+
return %2 : f32
595+
}
596+
// CHECK-SAME: (%[[ARG0:.*]]: f32) -> f32 {
597+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
598+
// CHECK: return %[[SQ]] : f32

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ void TestExpandMathPass::runOnOperation() {
4646
populateExpandFloorFPattern(patterns);
4747
populateExpandCeilFPattern(patterns);
4848
populateExpandPowFPattern(patterns);
49+
populateExpandFPowIPattern(patterns);
4950
populateExpandRoundFPattern(patterns);
5051
populateExpandRoundEvenPattern(patterns);
5152
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)