Commit 5f6c036
[mlir][linalg] Extend scalable vectorisation support
This patch simply removes one of the pre-conditions of scalable vectorisation,
namely that only the trailing vector dimension can be scalable, e.g.
vector<2x4x[8]xi32>.

This limitation can be lifted following the recent work to support scalable
vectorisation in Linalg, most notably:
  * https://reviews.llvm.org/D154336
  * https://reviews.llvm.org/D153372

A test is added that demonstrates scalable vectorisation of `linalg.matmul`.
It is simply a copy of a similar test for fixed-width vectorisation. The
latter is updated: check lines are simplified and annotated with regex
patterns, in the spirit of [1]:

  Tests should be minimal, and only check what is absolutely necessary.

Note that this change is just one step towards scalable vectorisation in
Linalg. It will allow us to exercise new code paths in the context of
scalable vectors in Linalg and hence make further progress in the
forthcoming patches.

[1] https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices

Differential Revision: https://reviews.llvm.org/D158423
1 parent 007b41b commit 5f6c036
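For illustration, a minimal sketch of the case this patch enables, drawn from the `@matmul_scalable` test added below: the requested vector sizes are [8, [16], 4], so the scalable dimension is the middle, non-trailing one.

// A linalg.matmul vectorised with a non-trailing scalable dimension ([16]).
// Before this patch, the precondition would have rejected this request.
func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                outs(%C: memref<?x?xf32>)
  return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  // Scalable size in the second (non-trailing) position.
  transform.structured.masked_vectorize %matmul vector_sizes [8, [16], 4] : !transform.any_op
}

With the old pre-condition, vectorizeScalableVectorPrecondition would have bailed out here ("Non-trailing scalable vector dimensions are not supported"); after this patch the request is accepted.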

File tree (2 files changed: +52, −51 lines):
  * mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
  * mlir/test/Dialect/Linalg/vectorization-masked.mlir


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 0 additions & 26 deletions
@@ -169,21 +169,6 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
   return res;
 }
 
-/// Return true if the scalable vector dimensions are supported. For now, we
-/// only support scalable vectors in the trailing dimension.
-static bool areValidScalableVecDims(ArrayRef<bool> scalableVecDims) {
-  if (scalableVecDims.empty())
-    return true;
-
-  auto isScalable = [](bool isScalableVecSize) { return isScalableVecSize; };
-  if (std::any_of(scalableVecDims.begin(), scalableVecDims.end() - 1,
-                  isScalable)) {
-    return false;
-  }
-
-  return true;
-}
-
 /// Contains the vectorization state and related methods used across the
 /// vectorization process of a given operation.
 struct VectorizationState {
@@ -217,12 +202,6 @@ struct VectorizationState {
       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
     }
 
-    // Make sure we don't end up with unsupported scalable vector dimensions
-    // after the permutation. If so, we should bail out on that operation in the
-    // scalable preconditions.
-    assert(areValidScalableVecDims(scalableDims) &&
-           "Permuted scalable vector dimensions are not supported");
-
     return VectorType::get(vectorShape, elementType, scalableDims);
   }
 
@@ -1630,11 +1609,6 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (inputVectorSizes.empty())
     return success();
 
-  if (!areValidScalableVecDims(inputScalableVecDims)) {
-    LDBG("Non-trailing scalable vector dimensions are not supported\n");
-    return failure();
-  }
-
   bool isScalable = inputScalableVecDims.back();
   if (!isScalable)
     return success();

mlir/test/Dialect/Linalg/vectorization-masked.mlir

Lines changed: 52 additions & 25 deletions
@@ -447,41 +447,68 @@ transform.sequence failures(propagate) {
 
 // -----
 
-func.func @vectorize_dynamic_matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+func.func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
                 outs(%C: memref<?x?xf32>)
   return
 }
 
-// CHECK-LABEL: func.func @vectorize_dynamic_matmul(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<?x?xf32>, %[[VAL_1:.*]]: memref<?x?xf32>, %[[VAL_2:.*]]: memref<?x?xf32>) {
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
-// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
-// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
-// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
-// CHECK: %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
-// CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
-// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
-// CHECK: %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
-// CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK: vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
-// CHECK: return
-// CHECK: }
+// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK: %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK: %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
+// CHECK: %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
+// CHECK: %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
+// CHECK: %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
+// CHECK: %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x16x4xf32>
+// CHECK: %[[MASK_MULIT_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
+// CHECK: %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULIT_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
+// CHECK: %[[C2:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %matmul vector_sizes [8, 16, 4] : !transform.any_op
+}
+
+// -----
+
+func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+                outs(%C: memref<?x?xf32>)
+  return
 }
 
+// CHECK-LABEL: func.func @matmul_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
+// CHECK: %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
+// CHECK: %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<8x4xi1> -> vector<8x[16]x4xf32>
+// CHECK: %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x[16]xi1>
+// CHECK: %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<4x[16]xi1> -> vector<8x[16]x4xf32>
+// CHECK: %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x[16]xi1>
+// CHECK: %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x[16]xf32> } : vector<8x[16]xi1> -> vector<8x[16]xf32>
+// CHECK: %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x[16]x4xf32>
+// CHECK: %[[MASK_MULIT_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x[16]x4xi1>
+// CHECK: %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULIT_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x[16]x4xf32> to vector<8x[16]xf32> } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
+// CHECK: %[[C2:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<8x[16]xf32>, memref<?x?xf32> } : vector<8x[16]xi1>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %matmul vector_sizes [8, [16], 4] : !transform.any_op
+}
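Usage note (not part of this diff, and an assumption about the test harness rather than something stated in the patch): tests in vectorization-masked.mlir are driven through the transform dialect interpreter, typically via a RUN line at the top of the file along these lines:

// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s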
