@@ -489,37 +489,43 @@ class TransferWriteDropUnitDimsPattern
489
489
490
490
// / Return true if `vectorType` is a contiguous slice of `memrefType`.
491
491
// /
492
- // / Compares `vectorType` against the trailing dimensions (*) of `memrefType`
493
- // / to check whether `vectorType` is a contiguous slice of `memrefType`.
492
+ // / Compares `vectorType` against the trailing dimensions of `memrefType`
493
+ // / to check whether `vectorType` is a contiguous slice of `memrefType`. This
494
+ // / is implemented by iterating over the dims of `vectorType` and `memrefType`
495
+ // / and comparing them starting from the inner-most/right-most dims.
494
496
// /
495
- // / There are two cases:
497
+ // / Note that there might be some restriction on the leading dim of
498
+ // / `VectorType`:
499
+ // / 1. if all the trialing dims of `vectorType` match the trailing dims
500
+ // / of `memrefType` then the leading dim of `vectorType` can be arbitrary:
496
501
// /
497
- // / 1. The trailing dimensions of `memrefType` match the dimensions of
498
- // / `vectorType` excluding the front dim (the leading dim of `vectorType` does
499
- // / not matter in this case):
502
+ // / 1.1 contiguous slice, perfect match
503
+ // / vector<4x3x2xi32> from memref<5x4x3x2xi32>
504
+ // / 1.2 contiguous slice, all dims match except the leading dim: 2 != 4
505
+ // / vector<2x3x2xi32> from memref<5x4x3x2xi32>
500
506
// /
501
- // / vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
502
- // / vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
507
+ // / 2. if an "internal" dim of `vectorType` does not match the corresponding
508
+ // / trailing dim in `memrefType` then the remaining leading dims of
509
+ // / `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
503
510
// /
504
- // / 2. The trailing dimension of `memrefType` match the trailing dimensions of
505
- // / `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
506
- // / first dim of `vectorType` that does not match can be arbitrary, but the
507
- // / remaining leading dims have to be 1:
511
+ // / 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
512
+ // / vector<2x2x2xi32> from memref<5x4x3x2xi32>
513
+ // / 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
514
+ // / vector<1x2x2xi32> from memref<5x4x3x2xi32>
515
+ // / 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
516
+ // / vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
517
+ // / 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
518
+ // / vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
508
519
// /
509
- // / vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
510
- // / vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
511
- // /
512
- // / In both cases `memrefType` has to be contiguous (this is checked by looking
520
+ // / In all cases `memrefType` has to be contiguous (this is checked by looking
513
521
// / at strides).
514
- // /
515
- // / (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
516
- // / TODO: Update
517
522
static bool isContiguousSlice (MemRefType memrefType, VectorType vectorType) {
518
523
524
+ // Get the shape of `vectorType`. The leading dim is treated seperately.
519
525
ArrayRef<int64_t > targetShape = vectorType.getShape ();
520
526
auto targetShapeTrailingDims = targetShape.drop_front (1 );
521
527
522
- // Not used
528
+ // Get the strides of the memref.
523
529
int64_t offset;
524
530
SmallVector<int64_t > strides;
525
531
if (!succeeded (getStridesAndOffset (memrefType, strides, offset)))
@@ -538,6 +544,9 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
538
544
// current dim. This will be a product of the leading dims, hence initialising
539
545
// to 1.
540
546
int64_t flatDim = 1 ;
547
+
548
+ // Iterate overall all dim of `vectorType` excluding the leading dim and
549
+ // compare them against the trailing dims of `memrefType`.
541
550
strides.pop_back ();
542
551
for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse (llvm::zip (
543
552
targetShapeTrailingDims, memrefType.getShape (), strides))) {
@@ -547,14 +556,18 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
547
556
if (flatDim != memrefStride)
548
557
return false ;
549
558
550
- // If a non-matching dim was found, then the remaining dims of `VectorType`
551
- // should be 1.
559
+ // If a non-matching dim was found previously , then the remaining dims of
560
+ // `VectorType` should be 1.
552
561
if (!allTrailingDimsMatch && (targetDim != 1 ))
553
562
return false ;
554
563
555
564
allTrailingDimsMatch = (targetDim == memrefDim);
556
565
}
557
566
567
+ // If all dims of `vectorType` (excluding the leading dim) match the trailing
568
+ // dims `memrefType`, then this is a contiguous load. If there was a
569
+ // mismatch, then the internal dims have already been verified to be unit
570
+ // dims, but the leading dim still has to be checked.
558
571
return allTrailingDimsMatch ? true : (targetShape[0 ] == 1 );
559
572
}
560
573
0 commit comments