Skip to content

Commit 576b184

Browse files
committed
[mlir][vector] Add support for scalable vectors in trimLeadingOneDims
This patch updates one specific hook in "VectorDropLeadUnitDim.cpp" to make sure that "scalable dims" are handled correctly. While this change affects multiple patterns, I am only adding one regression tests that captures one specific case that affects me right now. I am also adding Vector dialect to the list of dependencies of `-test-vector-to-vector-lowering`. Otherwise my test case won't work as a standalone test. Differential Revision: https://reviews.llvm.org/D157993
1 parent 386aa2a commit 576b184

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,23 @@ using namespace mlir::vector;
2323
// Returns `vector<1xT>` if `oldType` only has one element.
2424
static VectorType trimLeadingOneDims(VectorType oldType) {
2525
ArrayRef<int64_t> oldShape = oldType.getShape();
26-
ArrayRef<int64_t> newShape =
27-
oldShape.drop_while([](int64_t dim) { return dim == 1; });
26+
ArrayRef<int64_t> newShape = oldShape;
27+
28+
ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
29+
ArrayRef<bool> newScalableDims = oldScalableDims;
30+
31+
while (!newShape.empty() && newShape.front() == 1 &&
32+
!newScalableDims.front()) {
33+
newShape = newShape.drop_front(1);
34+
newScalableDims = newScalableDims.drop_front(1);
35+
}
36+
2837
// Make sure we have at least 1 dimension per vector type requirements.
29-
if (newShape.empty())
38+
if (newShape.empty()) {
3039
newShape = oldShape.take_back();
31-
return VectorType::get(newShape, oldType.getElementType());
40+
newScalableDims = oldType.getScalableDims().take_back();
41+
}
42+
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
3243
}
3344

3445
/// Return a smallVector of size `rank` containing all zeros.

mlir/test/Dialect/Vector/vector-transforms.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@ func.func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
1818
return %1: vector<4x2xf32>
1919
}
2020

21+
// Regression test. Previously, this example would trigger
22+
// CastAwayElementwiseLeadingOneDim as:
23+
// * `vector<2x[4]x1xf32>`, would be reformulated as
24+
// * `vector<2x4x1xf32>`.
25+
// With the updated shape, the conversion pattern would incorrectly assume that
26+
// some leading dims have been dropped.
27+
// CHECK-LABEL: func.func @no_change(
28+
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x[4]x1xf32>,
29+
// CHECK-SAME: %[[VAL_1:.*]]: vector<2x[4]x1xf32>)
30+
// CHECK-NEXT: %[[VAL_2:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : vector<2x[4]x1xf32>
31+
// CHECK-NEXT: return %[[VAL_2]]
32+
func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) -> vector<2x[4]x1xf32> {
33+
%1 = arith.mulf %arg0, %arg1 : vector<2x[4]x1xf32>
34+
return %1 : vector<2x[4]x1xf32>
35+
}
36+
2137
// CHECK-LABEL: func @add4x4
2238
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
2339
// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct TestVectorToVectorLowering
5555

5656
void getDependentDialects(DialectRegistry &registry) const override {
5757
registry.insert<affine::AffineDialect>();
58+
registry.insert<vector::VectorDialect>();
5859
}
5960

6061
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),

0 commit comments

Comments
 (0)