@@ -487,90 +487,6 @@ class TransferWriteDropUnitDimsPattern
487
487
488
488
} // namespace
489
489
490
- // / Return true if `vectorType` is a contiguous slice of `memrefType`.
491
- // /
492
- // / Compares `vectorType` against the trailing dimensions of `memrefType`
493
- // / to check whether `vectorType` is a contiguous slice of `memrefType`. This
494
- // / is implemented by iterating over the dims of `vectorType` and `memrefType`
495
- // / and comparing them starting from the inner-most/right-most dims.
496
- // /
497
- // / Note that there might be some restriction on the leading dim of
498
- // / `VectorType`:
499
- // / 1. if all the trailing dims of `vectorType` match the trailing dims
500
- // / of `memrefType` then the leading dim of `vectorType` can be arbitrary:
501
- // /
502
- // / 1.1 contiguous slice, perfect match
503
- // / vector<4x3x2xi32> from memref<5x4x3x2xi32>
504
- // / 1.2 contiguous slice, all dims match except the leading dim: 2 != 4
505
- // / vector<2x3x2xi32> from memref<5x4x3x2xi32>
506
- // /
507
- // / 2. if an "internal" dim of `vectorType` does not match the corresponding
508
- // / trailing dim in `memrefType` then the remaining leading dims of
509
- // / `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
510
- // /
511
- // / 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
512
- // / vector<2x2x2xi32> from memref<5x4x3x2xi32>
513
- // / 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
514
- // / vector<1x2x2xi32> from memref<5x4x3x2xi32>
515
- // / 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
516
- // / vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
517
- // / 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
518
- // / vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
519
- // /
520
- // / In all cases `memrefType` has to be contiguous (this is checked by looking
521
- // / at strides).
522
- static bool isContiguousSlice (MemRefType memrefType, VectorType vectorType) {
523
-
524
- // Get the shape of `vectorType`. The leading dim is treated separately.
525
- ArrayRef<int64_t > targetShape = vectorType.getShape ();
526
- auto targetShapeTrailingDims = targetShape.drop_front (1 );
527
-
528
- // Get the strides of the memref.
529
- int64_t offset;
530
- SmallVector<int64_t > strides;
531
- if (!succeeded (getStridesAndOffset (memrefType, strides, offset)))
532
- return false ;
533
-
534
- // Non-unit stride in the trailing dimension means this memref is
535
- // not contiguous.
536
- if (strides.back () != 1 )
537
- return false ;
538
-
539
- // Do all but the leading dim of `vectorType` and `memrefType` match?
540
- bool allTrailingDimsMatch = true ;
541
-
542
- // The trailing dimension of `memrefType` after collapsing/flattening the
543
- // current dim. This will be a product of the leading dims, hence initialising
544
- // to 1.
545
- int64_t flatDim = 1 ;
546
-
547
- // Iterate over all dim of `vectorType` (in reverse) excluding the leading dim
548
- // and compare them against the trailing dims of `memrefType`.
549
- strides.pop_back ();
550
- for (auto [targetDim, memrefDim, memrefStride] :
551
- llvm::reverse (llvm::zip (targetShapeTrailingDims,
552
- memrefType.getShape ().drop_front (1 ), strides))) {
553
- flatDim *= memrefDim;
554
- // If the memref stride does not match the flattened dim, then this is
555
- // memref is not contiguous.
556
- if (flatDim != memrefStride)
557
- return false ;
558
-
559
- // If a non-matching dim was found previously, then the remaining dims of
560
- // `VectorType` should be 1.
561
- if (!allTrailingDimsMatch && (targetDim != 1 ))
562
- return false ;
563
-
564
- allTrailingDimsMatch = (targetDim == memrefDim);
565
- }
566
-
567
- // If the trailing dims of `vectorType` and `memrefType` match, then this is a
568
- // contiguous load. If there was a mismatch, then the internal dims have
569
- // already been verified to be unit dims, but the leading dim still has to be
570
- // checked.
571
- return allTrailingDimsMatch ? true : (targetShape[0 ] == 1 );
572
- }
573
-
574
490
// / Creates a memref.collapse_shape collapsing all inner dimensions of the
575
491
// / input starting at `firstDimToCollapse`.
576
492
static Value collapseInnerDims (PatternRewriter &rewriter, mlir::Location loc,
@@ -630,7 +546,7 @@ class FlattenContiguousRowMajorTransferReadPattern
630
546
if (vectorType.getRank () <= 1 )
631
547
// Already 0D/1D, nothing to do.
632
548
return failure ();
633
- if (!isContiguousSlice (sourceType, vectorType))
549
+ if (!vector:: isContiguousSlice (sourceType, vectorType))
634
550
return failure ();
635
551
int64_t firstContiguousInnerDim =
636
552
sourceType.getRank () - vectorType.getRank ();
@@ -688,7 +604,7 @@ class FlattenContiguousRowMajorTransferWritePattern
688
604
if (vectorType.getRank () <= 1 )
689
605
// Already 0D/1D, nothing to do.
690
606
return failure ();
691
- if (!isContiguousSlice (sourceType, vectorType))
607
+ if (!vector:: isContiguousSlice (sourceType, vectorType))
692
608
return failure ();
693
609
int64_t firstContiguousInnerDim =
694
610
sourceType.getRank () - vectorType.getRank ();
0 commit comments