Skip to content

Commit db8af5d

Browse files
ashermancinellivar-const
authored andcommitted
[mlir][math] powi with negative exponent should invert at the end (llvm#135735)
Previously, an FPowI operation would invert the base *before* performing a sequence of multiplications, but this led to discrepancies between LLVM pow intrinsic folding and that coming from the math dialect. See compiler-rt's version, which does the inverse at the end of the calculation: compiler-rt/lib/builtins/powidf2.c
1 parent e557d66 commit db8af5d

File tree

2 files changed

+29
-29
lines changed

2 files changed

+29
-29
lines changed

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,6 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
197197
if (exponentValue > exponentThreshold)
198198
return failure();
199199

200-
// Inverse the base for negative exponent, i.e. for
201-
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
202-
if (exponentIsNegative)
203-
base = rewriter.create<DivOpTy>(loc, bcast(one), base);
204-
205200
Value result = base;
206201
// Transform to naive sequence of multiplications:
207202
// * For positive exponent case replace:
@@ -215,6 +210,11 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
215210
for (unsigned i = 1; i < exponentValue; ++i)
216211
result = rewriter.create<MulOpTy>(loc, result, base);
217212

213+
// Inverse the base for negative exponent, i.e. for
214+
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
215+
if (exponentIsNegative)
216+
result = rewriter.create<DivOpTy>(loc, bcast(one), result);
217+
218218
rewriter.replaceOp(op, result);
219219
return success();
220220
}

mlir/test/Dialect/Math/algebraic-simplification.mlir

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,11 @@ func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32
135135
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
136136
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
137137
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
138-
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
139-
// CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
140-
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
141-
// CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
142-
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
138+
// CHECK: %[[SMUL:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
139+
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL]]
140+
// CHECK: %[[VMUL:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
141+
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL]]
142+
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
143143
%c1 = arith.constant 2 : i32
144144
%v1 = arith.constant dense <2> : vector<4xi32>
145145
%0 = math.ipowi %arg0, %c1 : i32
@@ -162,13 +162,13 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi
162162
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
163163
// CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
164164
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
165-
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
166-
// CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
167-
// CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
168-
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
169-
// CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
170-
// CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
171-
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
165+
// CHECK: %[[SMUL1:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
166+
// CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[ARG0]]
167+
// CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL2]]
168+
// CHECK: %[[VMUL1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
169+
// CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[ARG1]]
170+
// CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL2]]
171+
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
172172
%c1 = arith.constant 3 : i32
173173
%v1 = arith.constant dense <3> : vector<4xi32>
174174
%0 = math.ipowi %arg0, %c1 : i32
@@ -225,11 +225,11 @@ func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32
225225
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
226226
// CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
227227
// CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
228-
// CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
229-
// CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
230-
// CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
231-
// CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
232-
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
228+
// CHECK: %[[SMUL:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
229+
// CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL]]
230+
// CHECK: %[[VMUL:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
231+
// CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL]]
232+
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
233233
%c1 = arith.constant 2 : i32
234234
%v1 = arith.constant dense <2> : vector<4xi32>
235235
%0 = math.fpowi %arg0, %c1 : f32, i32
@@ -252,13 +252,13 @@ func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf
252252
// CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]]
253253
// CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
254254
// CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]]
255-
// CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
256-
// CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
257-
// CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]]
258-
// CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
259-
// CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
260-
// CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]]
261-
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
255+
// CHECK: %[[SMUL1:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
256+
// CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[ARG0]]
257+
// CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL2]]
258+
// CHECK: %[[VMUL1:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
259+
// CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[ARG1]]
260+
// CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL2]]
261+
// CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
262262
%c1 = arith.constant 3 : i32
263263
%v1 = arith.constant dense <3> : vector<4xi32>
264264
%0 = math.fpowi %arg0, %c1 : f32, i32

0 commit comments

Comments
 (0)