Skip to content

Commit e3dabd3

Browse files
committed
fixup! [mlir][Vector] Update patterns for flattening vector.xfer Ops
Update comments
1 parent 9a3c60b commit e3dabd3

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ class TransferWriteDropUnitDimsPattern
496496
///
497497
/// Note that there might be some restriction on the leading dim of
498498
/// `VectorType`:
499-
/// 1. if all the trialing dims of `vectorType` match the trailing dims
499+
/// 1. if all the trailing dims of `vectorType` match the trailing dims
500500
/// of `memrefType` then the leading dim of `vectorType` can be arbitrary:
501501
///
502502
/// 1.1 contiguous slice, perfect match
@@ -521,7 +521,7 @@ class TransferWriteDropUnitDimsPattern
521521
/// at strides).
522522
static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
523523

524-
// Get the shape of `vectorType`. The leading dim is treated seperately.
524+
// Get the shape of `vectorType`. The leading dim is treated separately.
525525
ArrayRef<int64_t> targetShape = vectorType.getShape();
526526
auto targetShapeTrailingDims = targetShape.drop_front(1);
527527

@@ -531,25 +531,25 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
531531
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
532532
return false;
533533

534-
// Non-unit stride in the trailing dimension means that this is memref is
534+
// Non-unit stride in the trailing dimension means this memref is
535535
// not contiguous.
536536
if (strides.back() != 1)
537537
return false;
538538

539-
// Do all but the leading dim of `vectorType` and the trailing dims of
540-
// `memrefType` match?
539+
// Do all but the leading dim of `vectorType` and `memrefType` match?
541540
bool allTrailingDimsMatch = true;
542541

543542
// The trailing dimension of `memrefType` after collapsing/flattening the
544543
// current dim. This will be a product of the leading dims, hence initialising
545544
// to 1.
546545
int64_t flatDim = 1;
547546

548-
// Iterate overall all dim of `vectorType` excluding the leading dim and
549-
// compare them against the trailing dims of `memrefType`.
547+
// Iterate over all dim of `vectorType` (in reverse) excluding the leading dim
548+
// and compare them against the trailing dims of `memrefType`.
550549
strides.pop_back();
551-
for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
552-
targetShapeTrailingDims, memrefType.getShape(), strides))) {
550+
for (auto [targetDim, memrefDim, memrefStride] :
551+
llvm::reverse(llvm::zip(targetShapeTrailingDims,
552+
memrefType.getShape().drop_front(1), strides))) {
553553
flatDim *= memrefDim;
554554
// If the memref stride does not match the flattened dim, then this is
555555
// memref is not contiguous.
@@ -564,10 +564,10 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
564564
allTrailingDimsMatch = (targetDim == memrefDim);
565565
}
566566

567-
// If all dims of `vectorType` (excluding the leading dim) match the trailing
568-
// dims `memrefType`, then this is a contiguous load. If there was a
569-
// mismatch, then the internal dims have already been verified to be unit
570-
// dims, but the leading dim still has to be checked.
567+
// If the trailing dims of `vectorType` and `memrefType` match, then this is a
568+
// contiguous load. If there was a mismatch, then the internal dims have
569+
// already been verified to be unit dims, but the leading dim still has to be
570+
// checked.
571571
return allTrailingDimsMatch ? true : (targetShape[0] == 1);
572572
}
573573

0 commit comments

Comments
 (0)