-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) #95744
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) #95744
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Patch is 30.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95744.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
+///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
- dyn_cast<MemRefType>(collapsedSource.getType());
+ cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+ // 0. Check pre-conditions
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
+ // If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
- SmallVector<Value> collapsedIndices =
- getCollapsedIndices(rewriter, loc, sourceType.getShape(),
- transferWriteOp.getIndices(), firstDimToCollapse);
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // 1. Collapse the source memref
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
+ // 2. Generate input args for a new vector.transfer_read that will read
+ // from the collapsed memref.
+ // 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+ // 2.2 New indices
+ SmallVector<Value> collapsedIndices =
+ getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+ transferWriteOp.getIndices(), firstDimToCollapse);
+
+ // 3. Create new vector.transfer_write that writes to the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+ // 4. Replace the old transfer_write with the new one writing the
+ // collapsed shape
rewriter.eraseOp(transferWriteOp);
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..e96c4b785b406 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
func.func @transfer_read_dims_match_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
func.func @transfer_read_dims_match_contiguous_empty_stride(
%arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// contiguous subset of the memref, so "flattenable".
func.func @transfer_read_dims_mismatch_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
- return %v : vector<1x1x2x2xi8>
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
}
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,51 +78,53 @@ func.func @transfer_read_dims_mismatch_contiguous(
// -----
func.func @transfer_read_dims_mismatch_non_zero_indices(
- %idx_1: index,
- %idx_2: index,
- %m_in: memref<1x43x4x6xi32>,
- %m_out: memref<1x2x6xi32>) {
+ %idx_1: index,
+ %idx_2: index,
+ %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
- vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x2x6xi32>
- return
+ return %v : vector<1x2x6xi32>
}
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
// -----
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
- %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
- %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+ %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index,
+ %idx1 : index) -> vector<2x2xf32> {
+
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f32
- %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+ %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+ memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
return %8 : vector<2x2xf32>
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-128B: memref.collapse_shape
@@ -125,80 +135,106 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// TODO: This case could be supported via memref.dim
func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
- %idx_1: index,
- %idx_2: index,
- %m_in: memref<1x?x4x6xi32>,
- %m_out: memref<1x2x6xi32>) {
+ %idx_1: index,
+ %idx_2: index,
+ %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
- vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x2x6xi32>
- return
+ return %v : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK-128B-NOT: memref.collapse_shape
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+ %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-128B-NOT: memref.collapse_shape
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
- %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+ %arg : memref<i8>) -> vector<i8> {
+
+ %cst = arith.constant 0 : i8
+ %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+ return %0 : vector<i8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
func.func @transfer_write_dims_match_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
}
// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
@@ -208,68 +244,161 @@ func.func @transfer_write_dims_match_contiguous(
// -----
+func.func @transfer_write_dims_match_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+ return
+}
+
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// CHECK-128B-LABEL: func @transfer_write_dims_match_cont...
[truncated]
|
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`, i.e. move it near other tests for xfer_read, unify variable names to match other xfer_read tests, highlight what makes this a positive test to better contrast it with `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` 2. Similar changes for `@transfer_write_flattenable_with_dynamic_dims_and_indices`. Depends on llvm#95743 and llvm#95744 **Only review the top top commit**
40f49e9
to
8aba88b
Compare
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`, i.e. move it near other tests for xfer_read, unify variable names to match other xfer_read tests, highlight what makes this a positive test to better contrast it with `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` 2. Similar changes for `@transfer_write_flattenable_with_dynamic_dims_and_indices`. Depends on llvm#95743 and llvm#95744 **Only review the top top commit**
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Test coverage seems to be kept and can't find a single NIT.
Thank you for reviewing 🙏🏻 If you have some spare cycles left this week, please also take a look at 1/n: #96031. No worries if you are busy! |
I wish I could, I have nothing against it but, personally I do not have any opinion on the question. I'll let the community decide. |
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. `@transfer_{read|write}_dims_mismatch_non_contiguous` and `@transfer_read_flattenable_negative` duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. Both tests are deleted (`@transfer_{read|write}_dims_mismatch_non_contiguous_slice` is preserved). 2. `@transfer_read_flattenable_negative2` is replaced with `@transfer_read_non_contiguous_src` and `@transfer_write_non_contiguous_src` (i.e. a dedicated test for xfer_read and xfer_read with more descriptive func names) Depends on llvm#95743. **Only review the top commit.**
8aba88b
to
08c38a8
Compare
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`, i.e. move it near other tests for xfer_read, unify variable names to match other xfer_read tests, highlight what makes this a positive test to better contrast it with `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` 2. Similar changes for `@transfer_write_flattenable_with_dynamic_dims_and_indices`. Depends on llvm#95743 and llvm#95744 **Only review the top top commit**
…m#95744) The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. Below are the main contributions of this PR 1. Two tests duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`: * `@transfer_{read|write}_dims_mismatch_non_contiguous` and * `@transfer_read_flattenable_negative` duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. These tests are removed (the original test is preserved). 2. `@transfer_read_flattenable_negative2` is replaced with two tests with more descriptive names: * `@transfer_read_non_contiguous_src` (for `xfer_read`) and * `@transfer_write_non_contiguous_src` (for `xfer_write`)
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`, i.e. move it near other tests for xfer_read, unify variable names to match other xfer_read tests, highlight what makes this a positive test to better contrast it with `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` 2. Similar changes for `@transfer_write_flattenable_with_dynamic_dims_and_indices`. Depends on llvm#95743 and llvm#95744 **Only review the top top commit**
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`, i.e. move it near other tests for xfer_read, unify variable names to match other xfer_read tests, highlight what makes this a positive test to better contrast it with `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` 2. Similar changes for `@transfer_write_flattenable_with_dynamic_dims_and_indices`. Depends on llvm#95743 and llvm#95744 **Only review the top top commit**
) The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. For consistency with other tests, `@transfer_read_flattenable_with_dynamic_dims_and_indices` is renamed as `@transfer_read_leading_dynamic_dims`. It is also moved near other tests for `xfer_read`, variable names are updated to match other `xfer_read` tests 2. `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` is renamed as `@negative_transfer_read_dynamic_dim_to_flatten` to better highlight that it's a negative test and to contrast it with `@transfer_read_leading_dynamic_dims` (and to emphasise the difference between the two). 3. Similar changes for tests for `xfer_write`. 4. Make sure that we consistently use `%idx_N` (as opposed to `%idxN`). Follow-up for #95743 and #95744
) The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. The main contributions of this PR: 1. For consistency with other tests, `@transfer_read_flattenable_with_dynamic_dims_and_indices` is renamed as `@transfer_read_leading_dynamic_dims`. It is also moved near other tests for `xfer_read`, variable names are updated to match other `xfer_read` tests 2. `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim` is renamed as `@negative_transfer_read_dynamic_dim_to_flatten` to better highlight that it's a negative test and to contrast it with `@transfer_read_leading_dynamic_dims` (and to emphasise the difference between the two). 3. Similar changes for tests for `xfer_write`. 4. Make sure that we consistently use `%idx_N` (as opposed to `%idxN`). Follow-up for #95743 and #95744
The main goal of this and subsequent PRs is to unify and categorize
tests in:
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.
Below are the main contributions of this PR
Two tests duplicated
@transfer_{read|write}_dims_mismatch_non_contiguous_slice
:@transfer_{read|write}_dims_mismatch_non_contiguous
and@transfer_read_flattenable_negative
duplicated@transfer_{read|write}_dims_mismatch_non_contiguous_slice
.These tests are removed (the original test is preserved).
@transfer_read_flattenable_negative2
is replaced withtwo tests with more descriptive names:
@transfer_read_non_contiguous_src
(forxfer_read
) and@transfer_write_non_contiguous_src
(forxfer_write
)