Skip to content

Commit 2d4e856

Browse files
bviyerrsuderman
authored andcommitted
[mlir][math] Expand math.powf to exp, log and multiply
Powf functions are pushed directly to libm. This is problematic for situations where libm is not available. This patch will decompose the powf function into log of exponent multiplied by log of base and raise it to the exp. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D148164
1 parent d3a2436 commit 2d4e856

File tree

5 files changed

+91
-0
lines changed

5 files changed

+91
-0
lines changed

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
2020
void populateExpandFloorFPattern(RewritePatternSet &patterns);
2121
void populateExpandCeilFPattern(RewritePatternSet &patterns);
2222
void populateExpandExp2FPattern(RewritePatternSet &patterns);
23+
void populateExpandPowFPattern(RewritePatternSet &patterns);
2324
void populateExpandRoundFPattern(RewritePatternSet &patterns);
2425
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
2526

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,19 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
157157
rewriter.replaceOp(op, ret);
158158
return success();
159159
}
160+
// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
161+
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
162+
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
163+
Value operandA = op.getOperand(0);
164+
Value operandB = op.getOperand(1);
165+
Type opType = operandA.getType();
166+
167+
Value logA = b.create<math::LogOp>(opType, operandA);
168+
Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
169+
Value expResult = b.create<math::ExpOp>(opType, mult);
170+
rewriter.replaceOp(op, expResult);
171+
return success();
172+
}
160173

161174
// exp2f(float x) -> exp(x * ln(2))
162175
// Proof: Let's say 2^x = y
@@ -264,6 +277,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
264277
patterns.add(convertExp2fOp);
265278
}
266279

280+
void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
281+
patterns.add(convertPowfOp);
282+
}
283+
267284
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
268285
patterns.add(convertRoundOp);
269286
}

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,16 @@ func.func @roundf_func(%a: f64) -> f64 {
207207
%ret = math.round %a : f64
208208
return %ret : f64
209209
}
210+
211+
// -----
212+
213+
// CHECK-LABEL: func @powf_func
214+
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
215+
func.func @powf_func(%a: f64, %b: f64) ->f64 {
216+
// CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
217+
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
218+
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
219+
// CHECK: return [[EXPR]]
220+
%ret = math.powf %a, %b : f64
221+
return %ret : f64
222+
}

mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
4343
populateExpandFmaFPattern(patterns);
4444
populateExpandFloorFPattern(patterns);
4545
populateExpandCeilFPattern(patterns);
46+
populateExpandPowFPattern(patterns);
4647
populateExpandRoundFPattern(patterns);
4748
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
4849
}

mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,68 @@ func.func @roundf() {
100100
return
101101
}
102102

103+
// -------------------------------------------------------------------------- //
104+
// pow.
105+
// -------------------------------------------------------------------------- //
106+
func.func @func_powff64(%a : f64, %b : f64) {
107+
%r = math.powf %a, %b : f64
108+
vector.print %r : f64
109+
return
110+
}
111+
112+
func.func @powf() {
113+
// CHECK: 16
114+
%a = arith.constant 4.0 : f64
115+
%a_p = arith.constant 2.0 : f64
116+
call @func_powff64(%a, %a_p) : (f64, f64) -> ()
117+
118+
// CHECK: -nan
119+
%b = arith.constant -3.0 : f64
120+
%b_p = arith.constant 3.0 : f64
121+
call @func_powff64(%b, %b_p) : (f64, f64) -> ()
122+
123+
// CHECK: 2.343
124+
%c = arith.constant 2.343 : f64
125+
%c_p = arith.constant 1.000 : f64
126+
call @func_powff64(%c, %c_p) : (f64, f64) -> ()
127+
128+
// CHECK: 0.176171
129+
%d = arith.constant 4.25 : f64
130+
%d_p = arith.constant -1.2 : f64
131+
call @func_powff64(%d, %d_p) : (f64, f64) -> ()
132+
133+
// CHECK: 1
134+
%e = arith.constant 4.385 : f64
135+
%e_p = arith.constant 0.00 : f64
136+
call @func_powff64(%e, %e_p) : (f64, f64) -> ()
137+
138+
// CHECK: 6.62637
139+
%f = arith.constant 4.835 : f64
140+
%f_p = arith.constant 1.2 : f64
141+
call @func_powff64(%f, %f_p) : (f64, f64) -> ()
142+
143+
// CHECK: -nan
144+
%g = arith.constant 0xff80000000000000 : f64
145+
call @func_powff64(%g, %g) : (f64, f64) -> ()
146+
147+
// CHECK: nan
148+
%h = arith.constant 0x7fffffffffffffff : f64
149+
call @func_powff64(%h, %h) : (f64, f64) -> ()
150+
151+
// CHECK: nan
152+
%i = arith.constant 1.0 : f64
153+
call @func_powff64(%i, %h) : (f64, f64) -> ()
154+
155+
// CHECK: inf
156+
%j = arith.constant 29385.0 : f64
157+
%j_p = arith.constant 23598.0 : f64
158+
call @func_powff64(%j, %j_p) : (f64, f64) -> ()
159+
return
160+
}
103161

104162
func.func @main() {
105163
call @exp2f() : () -> ()
106164
call @roundf() : () -> ()
165+
call @powf() : () -> ()
107166
return
108167
}

0 commit comments

Comments
 (0)