[mlir][linalg] Relax scalable vectorization restrictions #117991

Merged · 2 commits · Nov 29, 2024

28 changes: 19 additions & 9 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
```diff
@@ -2022,26 +2022,36 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
   // it matches one of the supported cases:
-  //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
-  //  2. exactly 2 dims are scalable and those are the _last two adjacent_
-  //     parallel dims
-  //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
+  //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
+  //  (*).
+  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
+  //  parallel dims.
+  //  3. Exactly 1 reduction dim is scalable and that's the last (innermost)
+  //  dim.
   // The 2nd restriction above means that only Matmul-like Ops are supported
   // when 2 dims are scalable, e.g. :
   //  * iterators = [parallel, parallel, reduction]
   //  * scalable flags = [true, true, false]
+  //
+  // (*) Non-unit dims get folded away in practice.
   // TODO: Relax these conditions as good motivating examples are identified.
 
-  // Find the first scalable flag
-  bool seenParalell = false;
+  // Find the first scalable flag.
+  bool seenNonUnitParallel = false;
   auto iterators = linalgOp.getIteratorTypesArray();
   SmallVector<bool> scalableFlags(inputScalableVecDims);
-  while (!scalableFlags.back()) {
-    seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+  int64_t idx = scalableFlags.size() - 1;
+  while (!scalableFlags[idx]) {
+    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
+    seenNonUnitParallel |=
+        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
+
     iterators.pop_back();
     scalableFlags.pop_back();
+    --idx;
   }
 
+  // Analyze the iterator corresponding to the first scalable dim.
   switch (iterators.back()) {
   case utils::IteratorType::reduction: {
     // Check 3. above is met.
@@ -2059,7 +2069,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   }
   case utils::IteratorType::parallel: {
     // Check 1. and 2. above are met.
-    if (seenParalell) {
+    if (seenNonUnitParallel) {
       LDBG("Inner parallel dim not requested for scalable "
            "vectorization\n");
       return failure();
```
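As a worked illustration of case 2 above (the Matmul-like configuration, which this PR leaves in place), the transform-dialect request below asks for the last two adjacent parallel dims to be scalable. This is a minimal sketch, not part of this PR's diff: the function name is hypothetical and the sizes `[4], [4], 1` are illustrative, mirroring ArmSME-style tiling where both M and N are scalable.

```mlir
// Case 2: exactly 2 scalable dims, the last two adjacent *parallel* dims.
//   iterators      = [parallel, parallel, reduction]
//   scalable flags = [true, true, false]
func.func @matmul_two_scalable_dims(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>,
                                    %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Scalable M and N (the two trailing parallel dims); unit K.
    transform.structured.vectorize %0 vector_sizes [[4], [4], 1] : !transform.any_op
    transform.yield
  }
}
```

With sizes like these, vectorization produces masked operations on 2-D scalable vectors such as `vector<[4]x[4]xf32>`, which later lowerings (e.g. to SME outer products) are expected to handle.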
17 changes: 11 additions & 6 deletions mlir/test/Dialect/Linalg/vectorization-scalable.mlir
```diff
@@ -122,22 +122,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
+// NOTE: Often, non-trailing scalable sizes are problematic - there are no
+// "scalable" arrays of vectors at the LLVM level (multi-dim vectors are
+// decomposed into arrays of aggregates). However, the trailing dim in this
+// case is 1 and that can be folded away later.
+
+func.func @vectorize_dynamic_fill_leading_scalable(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
   %0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_dynamic_fill
+// CHECK-LABEL: func.func @vectorize_dynamic_fill_leading_scalable
 // CHECK: %[[DIM0:.*]] = tensor.dim
 // CHECK: %[[DIM1:.*]] = tensor.dim
-// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x[16]xi1>
-// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x[16]xf32>
-// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x[16]xf32>, tensor<?x?xf32> } : vector<8x[16]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x1xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<[8]x1xf32>
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<[8]x1xf32>, tensor<?x?xf32> } : vector<[8]x1xi1>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [8, [16]] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[8], 1] : !transform.any_op
     transform.yield
   }
 }
```
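The updated test exercises case 1: the only scalable dim is the leading parallel dim, and the trailing parallel dim is a unit dim, so after folding (e.g. `vector<[8]x1xf32>` to `vector<[8]xf32>`) the scalable dim is effectively trailing again. The remaining supported configuration, case 3 (a single scalable reduction dim in the innermost position), is untouched by this patch; the sketch below shows what requesting it looks like, under the same transform-dialect conventions, with a hypothetical function name and an illustrative `[4]` size.

```mlir
// Case 3: the only scalable dim is the innermost *reduction* dim.
func.func @reduce_1d_scalable(%in : tensor<?xf32>, %acc : tensor<f32>) -> tensor<f32> {
  %0 = linalg.reduce ins(%in : tensor<?xf32>) outs(%acc : tensor<f32>) dimensions = [0]
    (%a: f32, %b: f32) {
      %sum = arith.addf %a, %b : f32
      linalg.yield %sum : f32
    }
  return %0 : tensor<f32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // A single scalable size for the single (reduction) dim.
    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
    transform.yield
  }
}
```

Here the first scalable flag lands on a reduction iterator, so the precondition's `reduction` branch (check 3) applies rather than the `parallel` branch modified by this patch.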