[mlir][math] powi with negative exponent should invert at the end #135735

Merged Apr 15, 2025 (1 commit)

Conversation

ashermancinelli
Contributor

Previously, the strength reduction of an FPowI operation with a negative exponent inverted the base before performing the sequence of multiplications. This led to discrepancies between LLVM pow intrinsic folding and the result coming from the math dialect. This change inverts the result after the multiplications instead.

See compiler-rt's version, which performs the inversion at the end of the calculation: compiler-rt/lib/builtins/powidf2.c
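
To make the ordering difference concrete, here is a minimal C++ sketch, not the code from this PR. The names `powiInvertFirst`, `powiInvertLast`, and `powiReference` are hypothetical, and `powiReference` is only a paraphrase of the square-and-multiply loop in powidf2.c, so consult that file for the authoritative version. With the base inverted first, the rounding error of `1 / x` is carried through every multiplication. With the inversion at the end, there is a single division of an already accumulated product, so the two orders may disagree in the last bits for bases whose reciprocal is not exactly representable.

```cpp
#include <cassert>
#include <cstdlib>

// Old order (assumed shape): invert the base, then multiply repeatedly.
float powiInvertFirst(float x, int exponent) {
  assert(exponent != 0 && "sketch covers nonzero exponents only");
  const float base = exponent < 0 ? 1.0f / x : x;
  float result = base;
  for (int i = 1; i < std::abs(exponent); ++i)
    result *= base;
  return result;
}

// New order (assumed shape): multiply repeatedly, then invert once.
float powiInvertLast(float x, int exponent) {
  assert(exponent != 0 && "sketch covers nonzero exponents only");
  float result = x;
  for (int i = 1; i < std::abs(exponent); ++i)
    result *= x;
  return exponent < 0 ? 1.0f / result : result;
}

// Paraphrase of the compiler-rt approach: square-and-multiply on the
// original base, with the reciprocal taken only at the very end.
double powiReference(double a, int b) {
  const bool recip = b < 0;
  double r = 1;
  while (true) {
    if (b & 1)
      r *= a;
    b /= 2;
    if (b == 0)
      break;
    a *= a;
  }
  return recip ? 1 / r : r;
}
```

For a base such as `2.0f` every intermediate value is exact, so both orders agree. A base such as `3.0f` with exponent `-3` is the kind of input where the invert-first order can drift from `1.0f / 27.0f`.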

@llvmbot
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Asher Mancinelli (ashermancinelli)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/135735.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+5-5)
  • (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+24-24)
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index dcace489673f0..13e2a4b5541b2 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -197,11 +197,6 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   if (exponentValue > exponentThreshold)
     return failure();
 
-  // Inverse the base for negative exponent, i.e. for
-  // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
-  if (exponentIsNegative)
-    base = rewriter.create<DivOpTy>(loc, bcast(one), base);
-
   Value result = base;
   // Transform to naive sequence of multiplications:
   //   * For positive exponent case replace:
@@ -215,6 +210,11 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
   for (unsigned i = 1; i < exponentValue; ++i)
     result = rewriter.create<MulOpTy>(loc, result, base);
 
+  // Invert the result for a negative exponent, i.e. for
+  // `[fi]powi(x, negative_exponent)` set `result` to `1 / result`.
+  if (exponentIsNegative)
+    result = rewriter.create<DivOpTy>(loc, bcast(one), result);
+
   rewriter.replaceOp(op, result);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index a97ecc52a17e9..e0e2b9853a2a1 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -135,11 +135,11 @@ func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32
   // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
   // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
   // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  // CHECK: %[[SMUL:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL]]
+  // CHECK: %[[VMUL:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 2 : i32
   %v1 = arith.constant dense <2> : vector<4xi32>
   %0 = math.ipowi %arg0, %c1 : i32
@@ -162,13 +162,13 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi
   // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
   // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
   // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  // CHECK: %[[SMUL1:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL2]]
+  // CHECK: %[[VMUL1:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL2]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 3 : i32
   %v1 = arith.constant dense <3> : vector<4xi32>
   %0 = math.ipowi %arg0, %c1 : i32
@@ -225,11 +225,11 @@ func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32
   // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
   // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
   // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  // CHECK: %[[SMUL:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL]]
+  // CHECK: %[[VMUL:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 2 : i32
   %v1 = arith.constant dense <2> : vector<4xi32>
   %0 = math.fpowi %arg0, %c1 : f32, i32
@@ -252,13 +252,13 @@ func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf
   // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], %[[ARG0]]
   // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
   // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]]
-  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]]
-  // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]]
-  // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]]
-  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]]
-  // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]]
-  // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]]
-  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  // CHECK: %[[SMUL1:.*]] = arith.mulf %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[ARG0]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL2]]
+  // CHECK: %[[VMUL1:.*]] = arith.mulf %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[ARG1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL2]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]]
   %c1 = arith.constant 3 : i32
   %v1 = arith.constant dense <3> : vector<4xi32>
   %0 = math.fpowi %arg0, %c1 : f32, i32
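
The updated CHECK lines follow directly from where the division is created relative to the multiplication loop. The following is a small C++ sketch of that ordering, assuming a hypothetical helper `emitPowI` that records op names in creation order. It is a stand-in for illustration and does not use the MLIR rewriter API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of the pattern: returns the ops in the order they would
// be created for a constant nonzero exponent. `invertAtEnd` selects the new
// behavior (division last) versus the old one (division first).
std::vector<std::string> emitPowI(int exponent, bool invertAtEnd) {
  assert(exponent != 0 && "sketch covers nonzero exponents only");
  const bool negative = exponent < 0;
  const int magnitude = negative ? -exponent : exponent;

  std::vector<std::string> ops;
  if (negative && !invertAtEnd)
    ops.push_back("div"); // old: base = 1 / base
  for (int i = 1; i < magnitude; ++i)
    ops.push_back("mul"); // result = result * base
  if (negative && invertAtEnd)
    ops.push_back("div"); // new: result = 1 / result
  return ops;
}
```

For exponent `-3`, this model yields `mul, mul, div` with `invertAtEnd` set, matching the new `SMUL1`, `SMUL2`, `SCALAR1` sequence in the tests, and `div, mul, mul` otherwise.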

Contributor

@vzakhari vzakhari left a comment

Thank you, Asher!

@ashermancinelli ashermancinelli merged commit 9ab2dea into llvm:main Apr 15, 2025
13 of 14 checks passed
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…vm#135735)
