[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) #73523
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Updates patterns for flattening `vector.transfer_read` by relaxing the requirement that the "collapsed" indices are all zero. This enables collapsing cases like this one:

```mlir
%2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
  memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously, only the following case would be considered for collapsing:

```mlir
%2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
  memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

The pattern itself, `FlattenContiguousRowMajorTransferReadPattern`, was also refactored a bit: comments were added and `firstContiguousInnerDim` was renamed as `firstDimToCollapse` (the latter better matches the meaning and is already used consistently in the helper methods).

A similar update for `vector.transfer_write` will be implemented in a follow-up patch.

Full diff: https://github.com/llvm/llvm-project/pull/73523.diff

2 Files Affected:
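For the non-zero-index example above, the flattened output, per the new `transfer_read_dims_mismatch_non_zero_indices` test added in this patch, looks roughly as follows (SSA names simplified; in that test the read feeds a flattened `vector.transfer_write`, so the trailing `shape_cast` folds away):

```mlir
// Collapse the trailing dims of the source memref.
%collapsed = memref.collapse_shape %m_in [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
// Offset into the collapsed trailing dims, computed from the original
// indices (constants shown as in this revision's CHECK lines).
%mul_1  = arith.muli %idx_1, %c43 : index
%mul_2  = arith.muli %idx_2, %c4 : index
%offset = arith.addi %mul_2, %mul_1 : index
// 1-D read from the collapsed memref, then restore the original vector shape.
%read   = vector.transfer_read %collapsed[%c0, %offset], %c0_i32 {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
%res    = vector.shape_cast %read : vector<12xi32> to vector<1x2x6xi32>
```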
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..951a378b84cf0e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -487,26 +487,76 @@ class TransferWriteDropUnitDimsPattern
} // namespace
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
- ArrayRef<int64_t> targetShape) {
- auto shape = memrefType.getShape();
- SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+///
+/// There are two cases:
+///
+/// 1. The trailing dimensions of `memrefType` match the dimensions of
+/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
+/// not matter in this case):
+///
+/// vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+/// vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// 2. The trailing dimension of `memrefType` match the trailing dimensions of
+/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
+/// first dim of `vectorType` that does not match can be arbitrary, but the
+/// remaining leading dims have to be 1:
+///
+/// vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+/// vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+///
+/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
+/// TODO: Update
+static bool isContiguousSlice(MemRefType memrefType,
+ VectorType vectorType) {
+
+ ArrayRef<int64_t> targetShape = vectorType.getShape();
+ auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+ // Not used
int64_t offset;
+ SmallVector<int64_t> strides;
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
return false;
+
+ // Non-unit stride in the trailing dimension means that this is memref is
+ // not contiguous.
if (strides.back() != 1)
return false;
- strides.pop_back();
+
+ // Do all but the leading dim of `vectorType` and the trailing dims of
+ // `memrefType` match?
+ bool allTrailingDimsMatch = true;
+
+ // The trailing dimension of `memrefType` after collapsing/flattening the
+ // current dim. This will be a product of the leading dims, hence initialising
+ // to 1.
int64_t flatDim = 1;
- for (auto [targetDim, memrefDim, memrefStride] :
- llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+ strides.pop_back();
+ for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+ targetShapeTrailingDims, memrefType.getShape(), strides))) {
flatDim *= memrefDim;
- if (flatDim != memrefStride || targetDim != memrefDim)
+ // If the memref stride does not match the flattened dim, then this is
+ // memref is not contiguous.
+ if (flatDim != memrefStride)
+ return false;
+
+ // If a non-matching dim was found, then the remaining dims of `VectorType`
+ // should be 1.
+ if (!allTrailingDimsMatch && (targetDim != 1))
return false;
+
+ allTrailingDimsMatch = (targetDim == memrefDim);
}
- return true;
+
+ return allTrailingDimsMatch ? true : (targetShape[0] == 1);
}
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -529,6 +579,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
SmallVector<Value> &outIndices) {
@@ -562,18 +614,16 @@ class FlattenContiguousRowMajorTransferReadPattern
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+ // 0. Check pre-conditions
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
+ // If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
- // Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstContiguousInnerDim =
- sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
@@ -581,26 +631,76 @@ class FlattenContiguousRowMajorTransferReadPattern
return failure();
if (transferReadOp.getMask())
return failure();
+
SmallVector<Value> collapsedIndices;
- if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
- firstContiguousInnerDim,
- collapsedIndices)))
- return failure();
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+ // 1. Collapse the source memref
Value collapsedSource =
- collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+ collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
dyn_cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
- assert(collapsedRank == firstContiguousInnerDim + 1);
+ assert(collapsedRank == firstDimToCollapse + 1);
+
+ // 2. Generate input args for a new vector.transfer_read that will read
+ // from the collapsed memref.
+ // 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
- getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+ getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+ // 2.2 New indices
+ // If all the collapsed indices are zero then no extra logic is needed.
+ // Otherwise, a new offset/index has to be computed.
+ if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+ firstDimToCollapse,
+ collapsedIndices))) {
+ // Copy all the leading indices
+ collapsedIndices = transferReadOp.getIndices();
+ collapsedIndices.resize(firstDimToCollapse);
+
+ // Compute the remaining trailing index/offset required for reading from
+ // the collapsed memref:
+ //
+ // offset = 0
+ // for (i = firstDimToCollapse; i < outputRank; ++i)
+ // offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+ //
+ // For this example:
+ // %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+ // memref<1x43x2xi32>, vector<1x2xi32>
+ // which would be collapsed to:
+ // %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+ // memref<1x86xi32>, vector<2xi32>
+ // one would get the following offset:
+ // %offset = %arg0 * 43
+ int64_t outputRank = transferReadOp.getIndices().size();
+ Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+ Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+ auto sourceDimSize =
+ rewriter.create<memref::DimOp>(loc, source, dimIdx);
+
+ offset = rewriter.create<arith::AddIOp>(
+ loc,
+ rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
+ sourceDimSize),
+ offset);
+ }
+ collapsedIndices.push_back(offset);
+ }
+
+ // 3. Create new vector.transfer_read that reads from the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+ // 4. Replace the old transfer_read with the new one reading from the
+ // collapsed shape
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transferReadOp, cast<VectorType>(vector.getType()), flatRead);
return success();
@@ -628,9 +728,7 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d055a..8369069e31ab7c6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
return %v : vector<5x4x3x2xi8>
}
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,76 @@ func.func @transfer_read_flattenable_with_offset(
// -----
-func.func @transfer_write_flattenable_with_offset(
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+ %idx_1: index,
+ %idx_2: index,
+ %m_in: memref<1x43x4x6xi32>,
+ %m_out: memref<1x2x6xi32>) {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ memref<1x43x4x6xi32>, vector<1x2x6xi32>
+ vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x6xi32>, memref<1x2x6xi32>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME: %[[VAL_3:.*]]: memref<1x2x6xi32>) {
+// CHECK: %[[VAL_4:.*]] = arith.constant 43 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
+// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
+// CHECK: %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK: %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK: vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +95,7 @@ func.func @transfer_write_flattenable_with_offset(
return
}
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +104,46 @@ func.func @transfer_write_flattenable_with_offset(
// -----
+func.func @transfer_write_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
return
}
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK-SAME: %[[VEC:.+]]: vector<i8>
-// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK: return
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
@@ -54,11 +153,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
return %0 : vector<i8>
}
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK: %[[CST:.+]] = arith.constant 0 : i8
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK: return %[[READ]]
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
Depends on #73522 - please only review the top commit 🙏🏻 .
Force-pushed: cc810a7 to fecd909
✅ With the latest revision this PR passed the C/C++ code formatter.

I'm not a legitimate reviewer here but I'll vouch for the usefulness of this change. Flattening patterns are often key to consistently good codegen of transfer ops and this PR seems to remove an unnecessary limitation.
Updates patterns for flattening `vector.transfer_read` by relaxing the requirement that the "collapsed" indices are all zero. This enables collapsing cases like this one:

```mlir
%2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
  memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously only the following case would be considered for collapsing:

```mlir
%2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
  memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

The pattern itself, `FlattenContiguousRowMajorTransferReadPattern`, was a bit refactored too:
* added comments,
* renamed `firstContiguousInnerDim` as `firstDimToCollapse` (the latter better matches the meaning and is already consistently used in various helper methods that use it).

Similar update for `vector.transfer_write` will be implemented in a follow-up patch.
…(2/N) Refactor to use makeComposedFoldedAffineApply
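The commit above mentions refactoring the offset computation to use `makeComposedFoldedAffineApply`. A minimal sketch of what that could look like is below; this is illustrative wiring rather than the PR's actual code, and the helper name, `indicesToCollapse`, and the per-dimension `basis` multipliers are assumptions.

```cpp
// Sketch (not the PR's exact code): linearize the indices being collapsed
// into one offset via an affine map so that constant indices fold away.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"

static mlir::Value linearizeCollapsedIndices(mlir::PatternRewriter &rewriter,
                                             mlir::Location loc,
                                             mlir::ValueRange indicesToCollapse,
                                             llvm::ArrayRef<int64_t> basis) {
  // Build offset = sum_i (basis[i] * index_i) as an affine expression.
  mlir::AffineExpr offsetExpr = rewriter.getAffineConstantExpr(0);
  llvm::SmallVector<mlir::OpFoldResult> mapOperands;
  for (auto it : llvm::enumerate(indicesToCollapse)) {
    offsetExpr =
        offsetExpr + rewriter.getAffineDimExpr(it.index()) * basis[it.index()];
    mapOperands.push_back(it.value());
  }
  // Compose and fold; constant operands disappear into the map.
  mlir::OpFoldResult ofr = mlir::affine::makeComposedFoldedAffineApply(
      rewriter, loc,
      mlir::AffineMap::get(/*dimCount=*/mapOperands.size(),
                           /*symbolCount=*/0, offsetExpr),
      mapOperands);
  // Turn the folded result back into a Value usable as a transfer index.
  return mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
}
```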
Force-pushed: fecd909 to b27c49d
Thanks for pushing on this and being patient with my review comments!
I am about to send a small update - it addresses comments from @hanhanW and also restricts the "rewrite" added here. If there are no new comments, I will land it tomorrow.
Thank you for taking a look :)
It works both ways - thank you for bearing with me :) And for excellent comments - really helped to improve this patch (same comment for my previous PR) 🙏🏻
…(2/N) Address comments from @hanhanW
SG, thank you!
Updates patterns for flattening `vector.transfer_read` by relaxing the requirement that the "collapsed" indices are all zero. This enables collapsing cases like the non-zero-index example shown at the top of this PR; previously, only the case where all of the collapsed indices are zero would be considered for collapsing.

Also adds some new comments and renames the `firstContiguousInnerDim` parameter as `firstDimToCollapse` (the latter better matches the actual meaning).

Similar updates for `vector.transfer_write` will be implemented in a follow-up patch.