Commit 71441ed
[mlir][Vector] Add vector bitwidth target to xfer op flattening (#81966)
This PR adds an optional bitwidth parameter to the vector xfer op flattening transformation so that flattening doesn't happen if the trailing dimension of the read/written vector is larger than this bitwidth (i.e., we are already able to fill at least one vector register with that size).
1 parent 162fa4d commit 71441ed
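To make the cutoff concrete (the numbers here are illustrative, not from the commit): with a 128-bit target, a transfer of vector<5x4x3x2xi8> has a trailing dimension of 2 x 8 = 16 bits and is still flattened, whereas a transfer of vector<2x8xi32> has 8 x 32 = 256 bits and is left as is. A minimal C++ sketch of the guard, distilled from the pattern code in this commit (the helper name fitsFlatteningTarget is hypothetical):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper mirroring the guard this commit adds to both
// flattening patterns: flatten only when the trailing vector dimension is
// narrower than the target register width.
static bool fitsFlatteningTarget(VectorType vectorType,
                                 unsigned targetVectorBitwidth) {
  // The bitwidth computation is only meaningful for signless int/float
  // element types.
  if (!vectorType.getElementType().isSignlessIntOrFloat())
    return false;
  unsigned trailingVectorDimBitwidth =
      vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
  // E.g. 2 x i8 = 16 bits < 128 -> flatten; 8 x i32 = 256 bits >= 128 -> skip.
  return trailingVectorDimBitwidth < targetVectorBitwidth;
}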

4 files changed: +137 −13 lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 7 additions & 2 deletions

@@ -330,8 +330,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is smaller or equal to `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
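For context, a minimal usage sketch of the extended entry point (the driver boilerplate below is assumed, not part of this commit). Omitting the new argument keeps the previous always-flatten behaviour, since the default is std::numeric_limits<unsigned>::max():

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: flatten contiguous transfer ops, but skip transfers
// whose trailing dimension already spans 128 bits or more.
static void flattenNarrowTransfers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateFlattenVectorTransferPatterns(
      patterns, /*targetVectorBitwidth=*/128);
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}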

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 40 additions & 5 deletions
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"

@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {

@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.

@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting

@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {

@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =

@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`

@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
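The fourth changed file, which wires the target-vector-bitwidth option of the -test-vector-transfer-flatten-patterns test pass used below, is not shown in this excerpt. A sketch of how such an option is typically threaded through an MLIR test pass (names and registration boilerplate assumed, not the commit's actual code):

// Hypothetical test-pass wiring; the real test pass file is not shown here.
struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {
  Option<unsigned> targetVectorBitwidth{
      *this, "target-vector-bitwidth",
      llvm::cl::desc("Flatten only transfers with a narrower trailing dim"),
      llvm::cl::init(std::numeric_limits<unsigned>::max())};

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    vector::populateFlattenVectorTransferPatterns(patterns,
                                                  targetVectorBitwidth);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};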

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 75 additions & 5 deletions
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
 func.func @transfer_read_dims_match_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

@@ -16,6 +17,9 @@ func.func @transfer_read_dims_match_contiguous(
 // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK: return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(

@@ -27,13 +31,16 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
   return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
 // CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK: return %[[VEC2D]]
 
+// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 // The shape of the memref and the vector don't match, but the vector is a

@@ -57,6 +64,9 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(

@@ -66,7 +76,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>

@@ -87,6 +97,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
 // CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 // The input memref has a dynamic trailing shape and hence is not flattened.

@@ -99,7 +112,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
     %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>

@@ -115,6 +128,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 // CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
 // CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous(

@@ -130,6 +146,9 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(

@@ -141,10 +160,13 @@ func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
   return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_match_contiguous(

@@ -155,13 +177,16 @@ func.func @transfer_write_dims_match_contiguous(
   return
 }
 
-// CHECK-LABEL: func @transfer_write_dims_match_contiguous
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
 // CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_contiguous(

@@ -182,6 +207,9 @@ func.func @transfer_write_dims_mismatch_contiguous(
 // CHECK: return
 // CHECK: }
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_dims_mismatch_non_contiguous(

@@ -196,6 +224,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {

@@ -207,6 +238,10 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_write_0d(
+// CHECK-128B-NOT:     memref.collapse_shape
+// CHECK-128B-NOT:     vector.shape_cast
+
 // -----
 
 func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {

@@ -219,6 +254,10 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
+// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-NOT:     memref.collapse_shape
+// CHECK-128B-NOT:     vector.shape_cast
+
 // -----
 
 func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {

@@ -241,6 +280,9 @@ func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memre
 // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK: return %[[VEC2D]] : vector<8x4xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {

@@ -260,6 +302,9 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 // CHECK-SAME: {in_bounds = [true]}
 // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
 
+// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
+// CHECK-128B:         memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative(

@@ -274,6 +319,9 @@ func.func @transfer_read_flattenable_negative(
 // CHECK-LABEL: func @transfer_read_flattenable_negative
 // CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @transfer_read_flattenable_negative2(

@@ -288,6 +336,9 @@ func.func @transfer_read_flattenable_negative2(
 // CHECK-LABEL: func @transfer_read_flattenable_negative2
 // CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
 
+// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {

@@ -302,6 +353,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
 // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
 // CHECK: return %[[VAL_4]] : vector<1x8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_basic(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) -> vector<1x8x1xi32> {

@@ -316,6 +370,9 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
 // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
 // CHECK: return %[[VAL_4]] : vector<1x8x1xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,

@@ -334,6 +391,9 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
 // CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
 // CHECK: return %[[VAL_4]] : vector<8xi32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_add(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,

@@ -352,6 +412,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 // CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
 // CHECK: return %[[VAL_4]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_mulf(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {

@@ -367,6 +430,9 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
 // CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
 // CHECK: return %[[VAL_2]] : vector<8x[2]xf32>
 
+// CHECK-128B-LABEL: func @fold_unit_dim_sitofp(
+// CHECK-128B-NOT:     memref.collapse_shape
+
 // -----
 
 // All shape casts are folded away

@@ -389,3 +455,7 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
+// CHECK-128B-NOT:     memref.collapse_shape
+
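Every shown test now has a CHECK-128B companion driven by the second RUN line. The positive cases (CHECK-128B: memref.collapse_shape) still flatten under the 128-bit target because their trailing dimensions are narrow, e.g. 2 x i8 = 16 bits or 4 x i8 = 32 bits, both well below 128; the CHECK-128B-NOT cases are rejected for the same reasons as before (non-contiguous slices, dynamic trailing shapes, or 0-D transfers). A single prefix can be reproduced locally by running mlir-opt with -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file over this file and piping into FileCheck with --check-prefix=CHECK-128B, exactly as in the RUN line above.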
