Skip to content

Commit d8e998b

Browse files
committed
fixup! [mlir][Vector] Update patterns for flattening vector.xfer Ops
Final refactor requested by Nicolas
1 parent e3dabd3 commit d8e998b

File tree

4 files changed

+84
-87
lines changed

4 files changed

+84
-87
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,39 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
4242
/// on a 2D slice. Otherwise, returns a failure.
4343
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
4444

45+
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
46+
///
47+
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
48+
/// checked (the other dims are not relevant). Note that for `vectorType` to be
49+
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
50+
/// to be contiguous - this is checked by looking at the corresponding strides.
51+
///
52+
/// There might be some restriction on the leading dim of `VectorType`:
53+
///
54+
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
55+
/// of `memrefType` then the leading dim of `vectorType` can be
56+
/// arbitrary.
57+
///
58+
/// Ex. 1.1 contiguous slice, perfect match
59+
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
60+
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
61+
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
62+
///
63+
/// Case 2. If an "internal" dim of `vectorType` does not match the
64+
/// corresponding trailing dim in `memrefType` then the remaining
65+
/// leading dims of `vectorType` have to be 1 (the first non-matching
66+
/// dim can be arbitrary).
67+
///
68+
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
69+
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
70+
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
71+
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
72+
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
73+
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
74+
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
75+
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
76+
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
77+
4578
} // namespace vector
4679

4780
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -487,90 +487,6 @@ class TransferWriteDropUnitDimsPattern
487487

488488
} // namespace
489489

490-
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
491-
///
492-
/// Compares `vectorType` against the trailing dimensions of `memrefType`
493-
/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
494-
/// is implemented by iterating over the dims of `vectorType` and `memrefType`
495-
/// and comparing them starting from the inner-most/right-most dims.
496-
///
497-
/// Note that there might be some restriction on the leading dim of
498-
/// `VectorType`:
499-
/// 1. if all the trailing dims of `vectorType` match the trailing dims
500-
/// of `memrefType` then the leading dim of `vectorType` can be arbitrary:
501-
///
502-
/// 1.1 contiguous slice, perfect match
503-
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
504-
/// 1.2 contiguous slice, all dims match except the leading dim: 2 != 4
505-
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
506-
///
507-
/// 2. if an "internal" dim of `vectorType` does not match the corresponding
508-
/// trailing dim in `memrefType` then the remaining leading dims of
509-
/// `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
510-
///
511-
/// 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
512-
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
513-
/// 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
514-
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
515-
/// 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
516-
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
517-
/// 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
518-
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
519-
///
520-
/// In all cases `memrefType` has to be contiguous (this is checked by looking
521-
/// at strides).
522-
static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
523-
524-
// Get the shape of `vectorType`. The leading dim is treated separately.
525-
ArrayRef<int64_t> targetShape = vectorType.getShape();
526-
auto targetShapeTrailingDims = targetShape.drop_front(1);
527-
528-
// Get the strides of the memref.
529-
int64_t offset;
530-
SmallVector<int64_t> strides;
531-
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
532-
return false;
533-
534-
// Non-unit stride in the trailing dimension means this memref is
535-
// not contiguous.
536-
if (strides.back() != 1)
537-
return false;
538-
539-
// Do all but the leading dim of `vectorType` and `memrefType` match?
540-
bool allTrailingDimsMatch = true;
541-
542-
// The trailing dimension of `memrefType` after collapsing/flattening the
543-
// current dim. This will be a product of the leading dims, hence initialising
544-
// to 1.
545-
int64_t flatDim = 1;
546-
547-
// Iterate over all dim of `vectorType` (in reverse) excluding the leading dim
548-
// and compare them against the trailing dims of `memrefType`.
549-
strides.pop_back();
550-
for (auto [targetDim, memrefDim, memrefStride] :
551-
llvm::reverse(llvm::zip(targetShapeTrailingDims,
552-
memrefType.getShape().drop_front(1), strides))) {
553-
flatDim *= memrefDim;
554-
// If the memref stride does not match the flattened dim, then this is
555-
// memref is not contiguous.
556-
if (flatDim != memrefStride)
557-
return false;
558-
559-
// If a non-matching dim was found previously, then the remaining dims of
560-
// `VectorType` should be 1.
561-
if (!allTrailingDimsMatch && (targetDim != 1))
562-
return false;
563-
564-
allTrailingDimsMatch = (targetDim == memrefDim);
565-
}
566-
567-
// If the trailing dims of `vectorType` and `memrefType` match, then this is a
568-
// contiguous load. If there was a mismatch, then the internal dims have
569-
// already been verified to be unit dims, but the leading dim still has to be
570-
// checked.
571-
return allTrailingDimsMatch ? true : (targetShape[0] == 1);
572-
}
573-
574490
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
575491
/// input starting at `firstDimToCollapse`.
576492
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
@@ -630,7 +546,7 @@ class FlattenContiguousRowMajorTransferReadPattern
630546
if (vectorType.getRank() <= 1)
631547
// Already 0D/1D, nothing to do.
632548
return failure();
633-
if (!isContiguousSlice(sourceType, vectorType))
549+
if (!vector::isContiguousSlice(sourceType, vectorType))
634550
return failure();
635551
int64_t firstContiguousInnerDim =
636552
sourceType.getRank() - vectorType.getRank();
@@ -688,7 +604,7 @@ class FlattenContiguousRowMajorTransferWritePattern
688604
if (vectorType.getRank() <= 1)
689605
// Already 0D/1D, nothing to do.
690606
return failure();
691-
if (!isContiguousSlice(sourceType, vectorType))
607+
if (!vector::isContiguousSlice(sourceType, vectorType))
692608
return failure();
693609
int64_t firstContiguousInnerDim =
694610
sourceType.getRank() - vectorType.getRank();

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,47 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
249249
// between parallel, reduction and possibly other cases.
250250
return ratio.has_value();
251251
}
252+
253+
bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
254+
if (vectorType.isScalable())
255+
return false;
256+
257+
ArrayRef<int64_t> vectorShape = vectorType.getShape();
258+
auto vecRank = vectorType.getRank();
259+
260+
// Extract the trailing dims and strides of the input memref
261+
auto memrefShape = memrefType.getShape().take_back(vecRank);
262+
int64_t offset;
263+
SmallVector<int64_t> stridesFull;
264+
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
265+
return false;
266+
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
267+
268+
// Cond 1: A contiguous memref will always have a unit trailing stride.
269+
if (strides.back() != 1)
270+
return false;
271+
272+
// Cond 2: Strides of a contiguous memref have to match the flattened dims.
273+
strides = strides.drop_back(1);
274+
SmallVector<int64_t> flattenedDims;
275+
for (size_t i = 1; i < memrefShape.size(); i++)
276+
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
277+
278+
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
279+
return false;
280+
281+
// Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
282+
// In the most basic case, all dims will match.
283+
auto firstNonMatchingDim =
284+
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
285+
memrefShape.rbegin(), memrefShape.rend());
286+
if (firstNonMatchingDim.first == vectorShape.rend())
287+
return true;
288+
289+
// One non-matching dim is still fine, however the remaining leading dims of
290+
// `vectorType` need to be 1.
291+
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
292+
vectorShape.rend());
293+
294+
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
295+
}

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
4141

4242
// -----
4343

44-
func.func @transfer_read_dims_mismatch_contiguous(
44+
func.func @transfer_read_dims_mismatch_non_contiguous(
4545
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
4646
%c0 = arith.constant 0 : index
4747
%cst = arith.constant 0 : i8
@@ -50,6 +50,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
5050
return %v : vector<2x1x2x2xi8>
5151
}
5252

53+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
5354
// CHECK-NOT: memref.collapse_shape
5455
// CHECK-NOT: vector.shape_cast
5556

@@ -100,6 +101,7 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
100101
return
101102
}
102103

104+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
103105
// CHECK-NOT: memref.collapse_shape
104106
// CHECK-NOT: vector.shape_cast
105107

@@ -110,6 +112,7 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
110112
return
111113
}
112114

115+
// CHECK-LABEL: func.func @transfer_write_0d
113116
// CHECK-NOT: memref.collapse_shape
114117
// CHECK-NOT: vector.shape_cast
115118

@@ -121,6 +124,7 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
121124
return %0 : vector<i8>
122125
}
123126

127+
// CHECK-LABEL: func.func @transfer_read_0d
124128
// CHECK-NOT: memref.collapse_shape
125129
// CHECK-NOT: vector.shape_cast
126130

0 commit comments

Comments
 (0)