Skip to content

Commit 441b82b

Browse files
authored
[mlir][NFC] IntegerRangeAnalysis: don't loop over splat attr (#115399)
Reland #115229 which was reverted by #115388 because it was hitting an assertion in IREE. From the original change: If the `DenseIntElementsAttr` is a splat value, there is no need to loop over the entire attr. Instead, just update with the splat value. The problem with the original implementation is that `SplatElementsAttr` might be an attr of non `APInt` (e.g. float) elements. Instead, check if `DenseIntElementsAttr` is splat and use the splat value. Added a test to ensure there's no crash when handling float attrs.
1 parent ccc9d7d commit 441b82b

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
4242
}
4343
if (auto arrayCstAttr =
4444
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
45+
if (arrayCstAttr.isSplat()) {
46+
setResultRange(getResult(), ConstantIntRanges::constant(
47+
arrayCstAttr.getSplatValue<APInt>()));
48+
return;
49+
}
50+
4551
std::optional<ConstantIntRanges> result;
4652
for (const APInt &val : arrayCstAttr) {
4753
auto range = ConstantIntRanges::constant(val);

mlir/test/Dialect/Vector/int-range-interface.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ func.func @constant_splat() -> vector<8xi32> {
1717
func.return %1 : vector<8xi32>
1818
}
1919

20+
// CHECK-LABEL: func @float_constant_splat
21+
// Don't crash on splat floats.
22+
func.func @float_constant_splat() -> vector<8xf32> {
23+
%0 = arith.constant dense<3.0> : vector<8xf32>
24+
func.return %0: vector<8xf32>
25+
}
26+
2027
// CHECK-LABEL: func @vector_splat
2128
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
2229
func.func @vector_splat() -> vector<4xindex> {

0 commit comments

Comments
 (0)