Skip to content

Commit 4a876b1

Browse files
committed
Add case to handle 0-D vectors in FlattenContiguousRowMajorTransferWritePattern and FlattenContiguousRowMajorTransferReadPattern.
For 0-D as well as 1-D vectors, both these patterns should return a failure as there is no need to collapse the shape of the source. Currently, only 1-D vectors were handled. This patch handles the 0-D case as well. Reviewed By: Benoit, ThomasRaoux Differential Revision: https://reviews.llvm.org/D119202
1 parent 079b6d0 commit 4a876b1

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ class FlattenContiguousRowMajorTransferReadPattern
373373
// Contiguity check is valid on tensors only.
374374
if (!sourceType)
375375
return failure();
376-
if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
377-
// Already 1D, nothing to do.
376+
if (vectorType.getRank() <= 1)
377+
// Already 0D/1D, nothing to do.
378378
return failure();
379379
if (!isStaticShapeAndContiguousRowMajor(sourceType))
380380
return failure();
@@ -425,8 +425,8 @@ class FlattenContiguousRowMajorTransferWritePattern
425425
// Contiguity check is valid on tensors only.
426426
if (!sourceType)
427427
return failure();
428-
if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
429-
// Already 1D, nothing to do.
428+
if (vectorType.getRank() <= 1)
429+
// Already 0D/1D, nothing to do.
430430
return failure();
431431
if (!isStaticShapeAndContiguousRowMajor(sourceType))
432432
return failure();

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,29 @@ func @transfer_write_flattenable_with_offset(
3333
// C-HECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
3434
// C-HECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
3535

36+
// -----
37+
38+
func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
39+
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
40+
return
41+
}
42+
43+
// CHECK-LABEL: func @transfer_write_0d
44+
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
45+
// CHECK-SAME: %[[VEC:.+]]: vector<i8>
46+
// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
47+
// CHECK: return
48+
49+
// -----
50+
51+
func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
52+
%cst = arith.constant 0 : i8
53+
%0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
54+
return %0 : vector<i8>
55+
}
56+
57+
// CHECK-LABEL: func @transfer_read_0d
58+
// CHECK-SAME: %[[ARG:.+]]: memref<i8>
59+
// CHECK: %[[CST:.+]] = arith.constant 0 : i8
60+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
61+
// CHECK: return %[[READ]]

0 commit comments

Comments
 (0)