Skip to content

Commit 687cdfd

Browse files
committed
[mlir][linalg] Relax scalable vectorization restrictions
Currently, the Linalg vectorizer disallows non-trailing parallel dimensions to be scalable, e.g., `vector_sizes [[8], 1]` (*), for cases like: ```mlir %0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32> ``` This restriction exists to avoid generating "scalable" arrays of aggregates, which LLVM does not support (multi-dim vectors are lowered into arrays of aggregates at the LLVM level). This patch relaxes that restriction when the trailing parallel vector dimension is `1`, e.g., for `vector_sizes [[8], 1]`. Such cases are safe since trailing unit dimensions can be collapsed. This relaxation is necessary to support scalable vectorization for tensor.pack, where inner tile sizes are `[8]` (scalable) and `1` (scalar). (*) Transform Dialect notation
1 parent 69d66fa commit 687cdfd

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2022,26 +2022,35 @@ vectorizeScalableVectorPrecondition(Operation *op,
20222022

20232023
// Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
20242024
// it matches one of the supported cases:
2025-
// 1. exactly 1 dim is scalable and that's the _last_ parallel dim
2026-
// 2. exactly 2 dims are scalable and those are the _last two adjacent_
2027-
// parallel dims
2028-
// 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
2025+
// 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2026+
// (*).
2027+
// 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2028+
// parallel dims.
2029+
// 3. Exactly 1 reduction dim is scalable and that's the last (innermost) dim.
20292030
// The 2nd restriction above means that only Matmul-like Ops are supported
20302031
// when 2 dims are scalable, e.g. :
20312032
// * iterators = [parallel, parallel, reduction]
20322033
// * scalable flags = [true, true, false]
2034+
//
2035+
// (*) Unit dims get folded away in practice.
2036+
// TODO: Relax these conditions as good motivating examples are identified.
20332037

2034-
// Find the first scalable flag
2035-
bool seenParalell = false;
2038+
// Find the first scalable flag, and ...
2039+
bool seenNonUnitParallel = false;
20362040
auto iterators = linalgOp.getIteratorTypesArray();
20372041
SmallVector<bool> scalableFlags(inputScalableVecDims);
2038-
while (!scalableFlags.back()) {
2039-
seenParalell |= (iterators.back() == utils::IteratorType::parallel);
2042+
int64_t idx = scalableFlags.size() - 1;
2043+
while (!scalableFlags[idx]) {
2044+
bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2045+
seenNonUnitParallel |=
2046+
(iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
20402047

20412048
iterators.pop_back();
20422049
scalableFlags.pop_back();
2050+
idx--;
20432051
}
20442052

2053+
// ... analyze the corresponding iterator.
20452054
switch (iterators.back()) {
20462055
case utils::IteratorType::reduction: {
20472056
// Check 3. above is met.
@@ -2059,7 +2068,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
20592068
}
20602069
case utils::IteratorType::parallel: {
20612070
// Check 1. and 2. above are met.
2062-
if (seenParalell) {
2071+
if (seenNonUnitParallel) {
20632072
LDBG("Inner parallel dim not requested for scalable "
20642073
"vectorization\n");
20652074
return failure();

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,22 +122,27 @@ module attributes {transform.with_named_sequence} {
122122

123123
// -----
124124

125-
func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
125+
// NOTE: Often, non-trailing scalable sizes are problematic - there are no
126+
// "scalable" arrays of vectors at the LLVM level (multi-dim vectors are
127+
// decomposed into arrays of aggregates). However, the trailing dim in this
128+
// case is 1 and that can be folded away later.
129+
130+
func.func @vectorize_dynamic_fill_leading_scalable(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
126131
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
127132
return %0 : tensor<?x?xf32>
128133
}
129134

130-
// CHECK-LABEL: func.func @vectorize_dynamic_fill
135+
// CHECK-LABEL: func.func @vectorize_dynamic_fill_leading_scalable
131136
// CHECK: %[[DIM0:.*]] = tensor.dim
132137
// CHECK: %[[DIM1:.*]] = tensor.dim
133-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
134-
// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
135-
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
138+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x1xi1>
139+
// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<[8]x1xf32>
140+
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<[8]x1xf32>, tensor<?x?xf32> } : vector<[8]x1xi1>
136141

137142
module attributes {transform.with_named_sequence} {
138143
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
139144
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
140-
transform.structured.vectorize %0 vector_sizes [8, [16]] : !transform.any_op
145+
transform.structured.vectorize %0 vector_sizes [[8], 1] : !transform.any_op
141146
transform.yield
142147
}
143148
}

0 commit comments

Comments
 (0)