Skip to content

Commit 9a3c60b

Browse files
committed
fixup! [mlir][Vector] Update patterns for flattening vector.xfer Ops
Update comments
1 parent 902ccc3 commit 9a3c60b

File tree

1 file changed

+35
-22
lines changed

1 file changed

+35
-22
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -489,37 +489,43 @@ class TransferWriteDropUnitDimsPattern
489489

490490
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
491491
///
492-
/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
493-
/// to check whether `vectorType` is a contiguous slice of `memrefType`.
492+
/// Compares `vectorType` against the trailing dimensions of `memrefType`
493+
/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
494+
/// is implemented by iterating over the dims of `vectorType` and `memrefType`
495+
/// and comparing them starting from the inner-most/right-most dims.
494496
///
495-
/// There are two cases:
497+
/// Note that there might be some restriction on the leading dim of
498+
/// `VectorType`:
499+
/// 1. if all the trialing dims of `vectorType` match the trailing dims
500+
/// of `memrefType` then the leading dim of `vectorType` can be arbitrary:
496501
///
497-
/// 1. The trailing dimensions of `memrefType` match the dimensions of
498-
/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
499-
/// not matter in this case):
502+
/// 1.1 contiguous slice, perfect match
503+
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
504+
/// 1.2 contiguous slice, all dims match except the leading dim: 2 != 4
505+
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
500506
///
501-
/// vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
502-
/// vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
507+
/// 2. if an "internal" dim of `vectorType` does not match the corresponding
508+
/// trailing dim in `memrefType` then the remaining leading dims of
509+
/// `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
503510
///
504-
/// 2. The trailing dimension of `memrefType` match the trailing dimensions of
505-
/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
506-
/// first dim of `vectorType` that does not match can be arbitrary, but the
507-
/// remaining leading dims have to be 1:
511+
/// 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
512+
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
513+
/// 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
514+
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
515+
/// 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
516+
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
517+
/// 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
518+
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
508519
///
509-
/// vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
510-
/// vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
511-
///
512-
/// In both cases `memrefType` has to be contiguous (this is checked by looking
520+
/// In all cases `memrefType` has to be contiguous (this is checked by looking
513521
/// at strides).
514-
///
515-
/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
516-
/// TODO: Update
517522
static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
518523

524+
// Get the shape of `vectorType`. The leading dim is treated seperately.
519525
ArrayRef<int64_t> targetShape = vectorType.getShape();
520526
auto targetShapeTrailingDims = targetShape.drop_front(1);
521527

522-
// Not used
528+
// Get the strides of the memref.
523529
int64_t offset;
524530
SmallVector<int64_t> strides;
525531
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
@@ -538,6 +544,9 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
538544
// current dim. This will be a product of the leading dims, hence initialising
539545
// to 1.
540546
int64_t flatDim = 1;
547+
548+
// Iterate overall all dim of `vectorType` excluding the leading dim and
549+
// compare them against the trailing dims of `memrefType`.
541550
strides.pop_back();
542551
for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
543552
targetShapeTrailingDims, memrefType.getShape(), strides))) {
@@ -547,14 +556,18 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
547556
if (flatDim != memrefStride)
548557
return false;
549558

550-
// If a non-matching dim was found, then the remaining dims of `VectorType`
551-
// should be 1.
559+
// If a non-matching dim was found previously, then the remaining dims of
560+
// `VectorType` should be 1.
552561
if (!allTrailingDimsMatch && (targetDim != 1))
553562
return false;
554563

555564
allTrailingDimsMatch = (targetDim == memrefDim);
556565
}
557566

567+
// If all dims of `vectorType` (excluding the leading dim) match the trailing
568+
// dims `memrefType`, then this is a contiguous load. If there was a
569+
// mismatch, then the internal dims have already been verified to be unit
570+
// dims, but the leading dim still has to be checked.
558571
return allTrailingDimsMatch ? true : (targetShape[0] == 1);
559572
}
560573

0 commit comments

Comments
 (0)