-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][math] powi with negative exponent should invert at the end #135735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][math] powi with negative exponent should invert at the end #135735
Conversation
Previously, an FPowI operation would invert the base *before* performing a sequence of multiplications, but this led to discrepancies between LLVM pow intrinsic folding and that coming from the math dialect. See compiler-rt's version, which does the inverse at the end of the calculation: compiler-rt/lib/builtins/powidf2.c
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: Asher Mancinelli (ashermancinelli) ChangesPreviously, an FPowI operation would invert the base before performing a sequence of multiplications, but this led to discrepancies between LLVM pow intrinsic folding and that coming from the math dialect. See compiler-rt's version, which does the inverse at the end of the calculation: compiler-rt/lib/builtins/powidf2.c Full diff: https://github.com/llvm/llvm-project/pull/135735.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index dcace489673f0..13e2a4b5541b2 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -197,11 +197,6 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
if (exponentValue > exponentThreshold)
return failure();
- // Inverse the base for negative exponent, i.e. for
- // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
- if (exponentIsNegative)
- base = rewriter.create<DivOpTy>(loc, bcast(one), base);
-
Value result = base;
// Transform to naive sequence of multiplications:
// * For positive exponent case replace:
@@ -215,6 +210,11 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
for (unsigned i = 1; i < exponentValue; ++i)
result = rewriter.create<MulOpTy>(loc, result, base);
+ // Inverse the base for negative exponent, i.e. for
+ // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
+ if (exponentIsNegative)
+ result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index a97ecc52a17e9..e0e2b9853a2a1 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -135,11 +135,11 @@ func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
- // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
- // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
- // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
- // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
- // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+ // CHECK: %[[SMUL:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+ // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL]]
+ // CHECK: %[[VMUL:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+ // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL]]
+ // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
%c1 = arith.constant 2 : i32
%v1 = arith.constant dense <2> : vector<4xi32>
%0 = math.ipowi %arg0, %c1 : i32
@@ -162,13 +162,13 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi
// CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
// CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
// CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
- // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
- // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
- // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
- // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
- // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
- // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
- // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+ // CHECK: %[[SMUL1:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+ // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[ARG0]]
+ // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL2]]
+ // CHECK: %[[VMUL1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+ // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[ARG1]]
+ // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL2]]
+ // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
%c1 = arith.constant 3 : i32
%v1 = arith.constant dense <3> : vector<4xi32>
%0 = math.ipowi %arg0, %c1 : i32
@@ -225,11 +225,11 @@ func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
// CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
- // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
- // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
- // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
- // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
- // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+ // CHECK: %[[SMUL:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+ // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL]]
+ // CHECK: %[[VMUL:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+ // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL]]
+ // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
%c1 = arith.constant 2 : i32
%v1 = arith.constant dense <2> : vector<4xi32>
%0 = math.fpowi %arg0, %c1 : f32, i32
@@ -252,13 +252,13 @@ func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf
// CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]]
// CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
// CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]]
- // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
- // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
- // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]]
- // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
- // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
- // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]]
- // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+ // CHECK: %[[SMUL1:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+ // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[ARG0]]
+ // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL2]]
+ // CHECK: %[[VMUL1:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+ // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[ARG1]]
+ // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL2]]
+ // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
%c1 = arith.constant 3 : i32
%v1 = arith.constant dense <3> : vector<4xi32>
%0 = math.fpowi %arg0, %c1 : f32, i32
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, Asher!
…vm#135735) Previously, an FPowI operation would invert the base *before* performing a sequence of multiplications, but this led to discrepancies between LLVM pow intrinsic folding and that coming from the math dialect. See compiler-rt's version, which does the inverse at the end of the calculation: compiler-rt/lib/builtins/powidf2.c
Previously, an FPowI operation would invert the base before performing a sequence of multiplications, but this led to discrepancies between LLVM pow intrinsic folding and that coming from the math dialect.
See compiler-rt's version, which does the inverse at the end of the calculation: compiler-rt/lib/builtins/powidf2.c