@@ -496,7 +496,7 @@ class TransferWriteDropUnitDimsPattern
 ///
 /// Note that there might be some restriction on the leading dim of
 /// `VectorType`:
-///    1. if all the trialing dims of `vectorType` match the trailing dims
+///    1. if all the trailing dims of `vectorType` match the trailing dims
 ///    of `memrefType` then the leading dim of `vectorType` can be arbitrary:
 ///
 ///    1.1 contiguous slice, perfect match
@@ -521,7 +521,7 @@ class TransferWriteDropUnitDimsPattern
 ///    at strides).
 static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {

-  // Get the shape of `vectorType`. The leading dim is treated seperately.
+  // Get the shape of `vectorType`. The leading dim is treated separately.
   ArrayRef<int64_t> targetShape = vectorType.getShape();
   auto targetShapeTrailingDims = targetShape.drop_front(1);

@@ -531,25 +531,25 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;

-  // Non-unit stride in the trailing dimension means that this is memref is
+  // Non-unit stride in the trailing dimension means this memref is
   // not contiguous.
   if (strides.back() != 1)
     return false;

-  // Do all but the leading dim of `vectorType` and the trailing dims of
-  // `memrefType` match?
+  // Do all but the leading dim of `vectorType` and `memrefType` match?
   bool allTrailingDimsMatch = true;

   // The trailing dimension of `memrefType` after collapsing/flattening the
   // current dim. This will be a product of the leading dims, hence initialising
   // to 1.
   int64_t flatDim = 1;

-  // Iterate overall all dim of `vectorType` excluding the leading dim and
-  // compare them against the trailing dims of `memrefType`.
+  // Iterate over all dim of `vectorType` (in reverse) excluding the leading dim
+  // and compare them against the trailing dims of `memrefType`.
   strides.pop_back();
-  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
-           targetShapeTrailingDims, memrefType.getShape(), strides))) {
+  for (auto [targetDim, memrefDim, memrefStride] :
+       llvm::reverse(llvm::zip(targetShapeTrailingDims,
+                               memrefType.getShape().drop_front(1), strides))) {
     flatDim *= memrefDim;
     // If the memref stride does not match the flattened dim, then this is
     // memref is not contiguous.
@@ -564,10 +564,10 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     allTrailingDimsMatch = (targetDim == memrefDim);
   }

-  // If all dims of `vectorType` (excluding the leading dim) match the trailing
-  // dims `memrefType`, then this is a contiguous load. If there was a
-  // mismatch, then the internal dims have already been verified to be unit
-  // dims, but the leading dim still has to be checked.
+  // If the trailing dims of `vectorType` and `memrefType` match, then this is a
+  // contiguous load. If there was a mismatch, then the internal dims have
+  // already been verified to be unit dims, but the leading dim still has to be
+  // checked.
   return allTrailingDimsMatch ? true : (targetShape[0] == 1);
 }

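The loop in this patch walks the trailing dims of `vectorType` and `memrefType` from the innermost dim outwards, checking each memref stride against the product of the inner dims and tracking whether every trailing dim matched. As a rough illustration, here is a self-contained C++ sketch of that check on plain shape/stride vectors. It is not the MLIR helper itself: `isContiguousSliceSketch` is a hypothetical stand-in, equal ranks are assumed, and the part of the loop elided between the hunks is approximated from the surrounding comments (a mismatched internal dim forces all outer vector dims to be unit dims).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the contiguity check: `memShape`/`memStrides` stand in for the
// memref shape and strides, `vecShape` for the vector shape. Equal ranks are
// assumed for simplicity.
static bool isContiguousSliceSketch(const std::vector<int64_t> &memShape,
                                    const std::vector<int64_t> &memStrides,
                                    const std::vector<int64_t> &vecShape) {
  assert(memShape.size() == memStrides.size() &&
         memShape.size() == vecShape.size() && "equal ranks assumed");

  // Non-unit stride in the trailing dimension means the memref is not
  // contiguous.
  if (memStrides.back() != 1)
    return false;

  bool allTrailingDimsMatch = true;
  int64_t flatDim = 1;
  // Iterate over all dims except the leading one, innermost first.
  for (size_t i = vecShape.size() - 1; i >= 1; --i) {
    // The stride of the next-outer memref dim must equal the flattened size
    // of the dims processed so far, otherwise the memref is not contiguous.
    flatDim *= memShape[i];
    if (memStrides[i - 1] != flatDim)
      return false;
    if (!allTrailingDimsMatch) {
      // A more-inner dim already mismatched: only unit dims are allowed here.
      if (vecShape[i] != 1)
        return false;
    } else {
      allTrailingDimsMatch = (vecShape[i] == memShape[i]);
    }
  }
  // A full trailing match allows an arbitrary leading dim; otherwise the
  // leading dim of the vector must be 1.
  return allTrailingDimsMatch || vecShape[0] == 1;
}

int main() {
  // Hypothetical shapes: a row-major contiguous memref<4x3x2> has strides
  // {6, 2, 1}.
  std::vector<int64_t> memShape = {4, 3, 2}, memStrides = {6, 2, 1};
  // Trailing dims match, so the leading vector dim may differ (2 != 4).
  assert(isContiguousSliceSketch(memShape, memStrides, {2, 3, 2}));
  // An internal dim mismatches (1 != 3), so the leading dim must be 1.
  assert(isContiguousSliceSketch(memShape, memStrides, {1, 1, 2}));
  assert(!isContiguousSliceSketch(memShape, memStrides, {2, 1, 2}));
  return 0;
}
```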