-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][NFC] IntegerRangeAnalysis: don't loop over splat attr #115399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Ian Wood (IanWood1) ChangesReland #115229 which was reverted by #115388 because it was hitting an assertion in IREE. From the original change: If the The problem with the original implementation is that Full diff: https://github.com/llvm/llvm-project/pull/115399.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 8682294c8a6972..f3413c1c30fadc 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -42,6 +42,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
}
if (auto arrayCstAttr =
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+ if (arrayCstAttr.isSplat()) {
+ setResultRange(getResult(), ConstantIntRanges::constant(
+ arrayCstAttr.getSplatValue<APInt>()));
+ return;
+ }
+
std::optional<ConstantIntRanges> result;
for (const APInt &val : arrayCstAttr) {
auto range = ConstantIntRanges::constant(val);
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 09dfe932a52323..e958ecaad45444 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -17,6 +17,12 @@ func.func @constant_splat() -> vector<8xi32> {
func.return %1 : vector<8xi32>
}
+// CHECK-LABEL: func @float_constant_splat
+func.func @float_constant_splat() -> vector<8xf32> {
+ %0 = arith.constant dense<3.0> : vector<8xf32>
+ func.return %0: vector<8xf32>
+}
+
// CHECK-LABEL: func @vector_splat
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
|
@llvm/pr-subscribers-mlir-vector Author: Ian Wood (IanWood1) ChangesReland #115229 which was reverted by #115388 because it was hitting an assertion in IREE. From the original change: If the The problem with the original implementation is that Full diff: https://github.com/llvm/llvm-project/pull/115399.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 8682294c8a6972..f3413c1c30fadc 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -42,6 +42,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
}
if (auto arrayCstAttr =
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
+ if (arrayCstAttr.isSplat()) {
+ setResultRange(getResult(), ConstantIntRanges::constant(
+ arrayCstAttr.getSplatValue<APInt>()));
+ return;
+ }
+
std::optional<ConstantIntRanges> result;
for (const APInt &val : arrayCstAttr) {
auto range = ConstantIntRanges::constant(val);
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 09dfe932a52323..e958ecaad45444 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -17,6 +17,12 @@ func.func @constant_splat() -> vector<8xi32> {
func.return %1 : vector<8xi32>
}
+// CHECK-LABEL: func @float_constant_splat
+func.func @float_constant_splat() -> vector<8xf32> {
+ %0 = arith.constant dense<3.0> : vector<8xf32>
+ func.return %0: vector<8xf32>
+}
+
// CHECK-LABEL: func @vector_splat
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @IanWood1
@@ -42,6 +42,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, | |||
} | |||
if (auto arrayCstAttr = | |||
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) { | |||
if (arrayCstAttr.isSplat()) { | |||
setResultRange(getResult(), ConstantIntRanges::constant( | |||
arrayCstAttr.getSplatValue<APInt>())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alternative is that you can still cast to SplatElementsAttr and then check the element type. Either way is fine.
This pass is causing long compilation times for llama3 405b (even when cherry-picking llvm/llvm-project#115399). The majority of the time is spent in this one pass. The compilation times improve when calling `eraseState` only when ops are deleted. This is similar to the upstream listeners in `UnsignedWhenEquivalent.cpp` and `IntRangeOptimizations.cpp`. It appears this function loops over all `LatticeAnchors` on each invocation to find the one to delete, causing it to be slow. My (nonrigorous) experiment showed a decrease from 18 min to 3 min compile time. My main concern here would be this affecting correctness, as I don't know if this has unaccounted for side effects. Signed-off-by: Ian Wood <[email protected]>
…5399) Reland llvm#115229 which was reverted by llvm#115388 because it was hitting an assertion in IREE. From the original change: If the `DenseIntElementsAttr` is a splat value, there is no need to loop over the entire attr. Instead, just update with the splat value. The problem with the original implementation is that `SplatElementsAttr` might be an attr of non `APInt` (e.g. float) elements. Instead, check if `DenseIntElementsAttr` is splat and use the splat value. Added a test to ensure there's no crash when handling float attrs.
…-org#19130) This pass is causing long compilation times for llama3 405b (even when cherry-picking llvm/llvm-project#115399). The majority of the time is spent in this one pass. The compilation times improve when calling `eraseState` only when ops are deleted. This is similar to the upstream listeners in `UnsignedWhenEquivalent.cpp` and `IntRangeOptimizations.cpp`. It appears this function loops over all `LatticeAnchors` on each invocation to find the one to delete, causing it to be slow. My (nonrigorous) experiment showed a decrease from 18 min to 3 min compile time. My main concern here would be this affecting correctness, as I don't know if this has unaccounted for side effects. Signed-off-by: Ian Wood <[email protected]>
…-org#19130) This pass is causing long compilation times for llama3 405b (even when cherry-picking llvm/llvm-project#115399). The majority of the time is spent in this one pass. The compilation times improve when calling `eraseState` only when ops are deleted. This is similar to the upstream listeners in `UnsignedWhenEquivalent.cpp` and `IntRangeOptimizations.cpp`. It appears this function loops over all `LatticeAnchors` on each invocation to find the one to delete, causing it to be slow. My (nonrigorous) experiment showed a decrease from 18 min to 3 min compile time. My main concern here would be this affecting correctness, as I don't know if this has unaccounted for side effects. Signed-off-by: Ian Wood <[email protected]> Signed-off-by: Giacomo Serafini <[email protected]>
Reland #115229 which was reverted by #115388 because it was hitting an assertion in IREE. From the original change: If the
DenseIntElementsAttr
is a splat value, there is no need to loop over the entire attr. Instead, just update with the splat value.The problem with the original implementation is that
SplatElementsAttr
might be an attr of nonAPInt
(e.g. float) elements. Instead, check ifDenseIntElementsAttr
is splat and use the splat value. Added a test to ensure there's no crash when handling float attrs.