Commit cfe043c
[mlir][linalg] Restrict scalable vectorisation (#98639)
Updates `vectorizeScalableVectorPrecondition` so that scalable vectorisation is only applied in well understood and tested scenarios.

It's unlikely that we would ever want an arbitrary dimension to be scalable. While the Linalg vectoriser should be flexible enough to handle all possibilities:
  * in more "exotic" cases, we are likely to struggle with lowerings further down the compilation stack,
  * it would be impractical given the limitations of LLVM (which usually reflect the limitations of actual hardware) - e.g. no support for "scalable" arrays of scalable or fixed-width vectors (*).

Ultimately, the goal of this patch is to better document what's currently supported. While this PR adds some new restrictions, no existing tests are affected.

(*) At the MLIR vector level that would correspond to e.g. `vector<[4]x8xf32>`.
1 parent 2df9fd7 commit cfe043c
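
As an illustrative aside (not part of the commit message), the type-level distinction behind the footnote can be sketched in MLIR; the `test.source` op below is a hypothetical placeholder used only to give the values a producer:

    // Lowerable: the trailing dim is scalable; after unrolling the leading
    // fixed dim this maps onto LLVM's <vscale x 8 x float>.
    %ok = "test.source"() : () -> vector<4x[8]xf32>

    // Not lowerable in practice: a "scalable array" of fixed-width vectors,
    // for which LLVM has no corresponding type.
    %bad = "test.source"() : () -> vector<[4]x8xf32>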

2 files changed: 128 additions, 10 deletions

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 62 additions & 9 deletions

@@ -1945,26 +1945,79 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
-/// Preconditions for scalable vectors.
+/// Preconditions for scalable vectors. This is quite restrictive - it models
+/// the fact that in practice we would only make selected dimensions scalable.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
                                     ArrayRef<int64_t> inputVectorSizes,
                                     ArrayRef<bool> inputScalableVecDims) {
   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
          "Number of input vector sizes and scalable dims doesn't match");
 
-  if (inputVectorSizes.empty())
-    return success();
+  size_t numOfScalableDims =
+      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
 
-  bool isScalable = inputScalableVecDims.back();
-  if (!isScalable)
+  if (numOfScalableDims == 0)
     return success();
 
-  // Only element-wise and 1d depthwise conv ops supported in the presence of
-  // scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && (isElementwise(linalgOp) ||
-                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
+
+  // Cond 1: There's been no need for scalable vectorisation of
+  // non-linalg Ops so far.
+  if (!linalgOp)
+    return failure();
+
+  // Cond 2: There's been no need for more than 2 scalable dims so far.
+  if (numOfScalableDims > 2)
+    return failure();
+
+  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify
+  // that it matches one of the supported cases:
+  //  1. Exactly 1 dim is scalable and that's the _last_ parallel dim.
+  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
+  //     parallel dims.
+  // The 2nd restriction above means that only Matmul-like Ops are supported
+  // when 2 dims are scalable, e.g.:
+  //    * iterators = [parallel, parallel, reduction]
+  //    * scalable flags = [true, true, false]
+
+  // Find the first scalable flag.
+  bool seenParallel = false;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<bool> scalableFlags(inputScalableVecDims);
+  while (!scalableFlags.back()) {
+    seenParallel |= (iterators.back() == utils::IteratorType::parallel);
+
+    iterators.pop_back();
+    scalableFlags.pop_back();
+  }
+
+  // TODO: Support scalable vectorisation for reduction dims.
+  if (iterators.back() == utils::IteratorType::reduction)
+    return failure();
+
+  // If this is not the _last_ parallel dim, case 1. above is not met.
+  if (seenParallel)
+    return failure();
+
+  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
+  // supported, for which we expect the following config:
+  //    * iterators = [parallel, parallel, reduction]
+  //    * scalable flags = [true, true, false]
+  if (numOfScalableDims == 2) {
+    scalableFlags.pop_back();
+    iterators.pop_back();
+
+    if (!scalableFlags.back() ||
+        (iterators.back() != utils::IteratorType::parallel))
+      return failure();
+  }
+
+  // Cond 4: Only the following ops are supported in the presence of
+  // scalable vectors.
+  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+                 isa<linalg::MatmulTransposeAOp>(op) ||
+                 isa<linalg::DepthwiseConv1DNwcWcOp>(op));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
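
To make Cond 3 concrete, here is an illustrative sketch (hypothetical, not taken from the patch; the function and payload names are made up) of a transform script the new precondition accepts: a matmul with iterators [parallel, parallel, reduction], vectorised with scalable flags [true, true, false], i.e. scalable M and N with fixed K:

    func.func @matmul_scalable_m_and_n(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
      linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
                    outs(%C : memref<?x?xf32>)
      return
    }

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
        %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
        // The last two adjacent parallel dims (M, N) are scalable - case 2 of Cond 3.
        transform.structured.vectorize %matmul vector_sizes [[8], [16], 4] : !transform.any_op
        transform.yield
      }
    }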

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 66 additions & 1 deletion

@@ -110,7 +110,7 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
+// -----
 
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
@@ -126,3 +126,68 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @linalg_reduce_scalable(%input: tensor<?xf32>,
+                                  %acc: tensor<f32>) -> tensor<f32> {
+
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %0 = linalg.reduce ins(%input : tensor<?xf32>) outs(%acc : tensor<f32>) dimensions = [0]
+    (%in: f32, %init: f32) {
+      %0 = arith.addf %in, %init : f32
+      linalg.yield %0 : f32
+    }
+  return %0 : tensor<f32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_generic_scalable_reduction_dim(%input: tensor<?x?xf32>,
+                                                 %acc: tensor<?xf32>) -> tensor<?xf32> {
+
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%input : tensor<?x?xf32>)
+    outs(%acc : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_matmul_scalable_leading_parallel_dim(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+                outs(%C: memref<?x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul vector_sizes [[8], 16, 4] : !transform.any_op
+    transform.yield
+  }
+}
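
For contrast with the last (rejected) test above, a hypothetical passing counterpart would keep the leading parallel dim fixed and make only the trailing parallel dim (N) scalable - case 1 of Cond 3. Such a case belongs in the supported-cases tests rather than in this file; names below are illustrative:

    func.func @linalg_matmul_scalable_trailing_parallel_dim(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
      // No "Attempted to vectorize, but failed" diagnostic expected here.
      linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
                    outs(%C : memref<?x?xf32>)
      return
    }

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
        %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
        transform.structured.vectorize %matmul vector_sizes [8, [16], 4] : !transform.any_op
        transform.yield
      }
    }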
