Skip to content

Commit 68a1944

Browse files
authored
[mlir][vector] Project out anonymous bounds in ScalableValueBoundsConstraintSet (#96499)
If the columns for anonymous bounds are not eliminated from the constraint set, computing a scalable bound fails in some cases. The test case is reduced from a real-world example.
1 parent 40278bb commit 68a1944

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
1010
#include "mlir/Dialect/Vector/IR/VectorOps.h"
11+
1112
namespace mlir::vector {
1213

1314
FailureOr<ConstantOrScalableBound::BoundSize>
@@ -74,6 +75,7 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
7475
return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
7576
};
7677
scalableCstr.projectOut(projectOutFn);
78+
scalableCstr.projectOutAnonymous(/*except=*/pos);
7779
// Also project out local variables (these are not tracked by the
7880
// ValueBoundsConstraintSet).
7981
for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {

mlir/test/Dialect/Vector/test-scalable-bounds.mlir

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,3 +215,31 @@ func.func @unsupported_negative_mod() {
215215
"test.some_use"(%bound) : (index) -> ()
216216
return
217217
}
218+
219+
// -----
220+
221+
// CHECK: #[[$SCALABLE_BOUND_MAP_5:.*]] = affine_map<()[s0] -> (s0 * 4)>
222+
223+
// CHECK-LABEL: @extract_slice_loop
224+
// CHECK: %[[VSCALE:.*]] = vector.vscale
225+
// CHECK: %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_5]]()[%[[VSCALE]]]
226+
// CHECK: "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
227+
228+
func.func @extract_slice_loop(%tensor: tensor<1x1x3x?xf32>) {
229+
%vscale = vector.vscale
230+
%c0 = arith.constant 0 : index
231+
%c1 = arith.constant 1 : index
232+
%c2 = arith.constant 2 : index
233+
%c3 = arith.constant 3 : index
234+
%c4 = arith.constant 4 : index
235+
%cst = arith.constant 0.0 : f32
236+
%c4_vscale = arith.muli %c4, %vscale : index
237+
%slice = tensor.extract_slice %tensor[0, 0, 0, 0] [1, 1, 3, %c4_vscale] [1, 1, 1, 1] : tensor<1x1x3x?xf32> to tensor<1x3x?xf32>
238+
%15 = scf.for %arg6 = %c0 to %c3 step %c1 iter_args(%arg = %slice) -> (tensor<1x3x?xf32>) {
239+
%dim = tensor.dim %arg, %c2 : tensor<1x3x?xf32>
240+
%bound = "test.reify_bound"(%dim) {type = "LB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
241+
"test.some_use"(%bound) : (index) -> ()
242+
scf.yield %arg : tensor<1x3x?xf32>
243+
}
244+
return
245+
}

0 commit comments

Comments
 (0)