Skip to content

Commit 8c2233b

Browse files
authored
[mlir][vector] Update docs + add tests (#137144)
This is a small follow-on for #133721: * Renamed `getRealVectorRank` as `getEffectiveVectorRankForXferOp` (to emphasise that this method was written specifically for transfer Ops). * Marginally tweaked the description for `getEffectiveVectorRankForXferOp` (mostly to highlight the two edge cases being covered). * Added tests for cases when the element type (of the shaped type) is a vector. * Unified the naming (and the order) of arguments in tests with the surrounding tests (e.g. `%vec_to_write` -> `%arg1`). Mostly for consistency (it would be good to use self-documenting names like `%vec_to_write` throughout).
1 parent 71329c6 commit 8c2233b

File tree

2 files changed

+41
-18
lines changed

2 files changed

+41
-18
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,29 +151,32 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151151
return false;
152152
}
153153

154-
/// Returns the number of dimensions of the `shapedType` that participate in the
155-
/// vector transfer, effectively the rank of the vector dimensions within the
156-
/// `shapedType`. This is calculated by taking the rank of the `vectorType`
157-
/// being transferred and subtracting the rank of the `shapedType`'s element
158-
/// type if it's also a vector.
154+
/// Returns the effective rank of the vector to read/write for Xfer Ops
159155
///
160-
/// This is used to determine the number of minor dimensions for identity maps
161-
/// in vector transfers.
156+
/// When the element type of the shaped type is _a scalar_, this will simply
157+
/// return the rank of the vector ( the result for xfer_read or the value to
158+
/// store for xfer_write).
162159
///
163-
/// For example, given a transfer operation involving `shapedType` and
164-
/// `vectorType`:
160+
/// When the element type of the base shaped type is _a vector_, returns the
161+
/// difference between the original vector type and the element type of the
162+
/// shaped type.
165163
///
164+
/// EXAMPLE 1 (element type is _a scalar_):
166165
/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167166
/// - shapedType.getElementType() = f32 (rank 0)
168167
/// - vectorType.getRank() = 2
169168
/// - Result = 2 - 0 = 2
170169
///
170+
/// EXAMPLE 2 (element type is _a vector_):
171171
/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172172
/// - shapedType.getElementType() = vector<20xf32> (rank 1)
173173
/// - vectorType.getRank() = 1
174174
/// - Result = 1 - 1 = 0
175-
static unsigned getRealVectorRank(ShapedType shapedType,
176-
VectorType vectorType) {
175+
///
176+
/// This is used to determine the number of minor dimensions for identity maps
177+
/// in vector transfer Ops.
178+
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
179+
VectorType vectorType) {
177180
unsigned elementVectorRank = 0;
178181
VectorType elementVectorType =
179182
llvm::dyn_cast<VectorType>(shapedType.getElementType());
@@ -192,7 +195,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
192195
/*numDims=*/0, /*numSymbols=*/0,
193196
getAffineConstantExpr(0, shapedType.getContext()));
194197
return AffineMap::getMinorIdentityMap(
195-
shapedType.getRank(), getRealVectorRank(shapedType, vectorType),
198+
shapedType.getRank(),
199+
getEffectiveVectorRankForXferOp(shapedType, vectorType),
196200
shapedType.getContext());
197201
}
198202

@@ -4261,7 +4265,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
42614265
Attribute permMapAttr = result.attributes.get(permMapAttrName);
42624266
AffineMap permMap;
42634267
if (!permMapAttr) {
4264-
if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
4268+
if (shapedType.getRank() <
4269+
getEffectiveVectorRankForXferOp(shapedType, vectorType))
42654270
return parser.emitError(typesLoc,
42664271
"expected a custom permutation_map when "
42674272
"rank(source) != rank(destination)");
@@ -4680,7 +4685,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
46804685
auto permMapAttr = result.attributes.get(permMapAttrName);
46814686
AffineMap permMap;
46824687
if (!permMapAttr) {
4683-
if (shapedType.getRank() < getRealVectorRank(shapedType, vectorType))
4688+
if (shapedType.getRank() <
4689+
getEffectiveVectorRankForXferOp(shapedType, vectorType))
46844690
return parser.emitError(typesLoc,
46854691
"expected a custom permutation_map when "
46864692
"rank(source) != rank(destination)");

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -525,15 +525,24 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
525525

526526
// -----
527527

528-
func.func @test_vector.transfer_read(%arg1: memref<?xindex>) -> vector<3x4xindex> {
528+
func.func @test_vector.transfer_read(%arg0: memref<?xindex>) -> vector<3x4xindex> {
529529
%c3 = arith.constant 3 : index
530530
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
531-
%0 = vector.transfer_read %arg1[%c3, %c3], %c3 : memref<?xindex>, vector<3x4xindex>
531+
%0 = vector.transfer_read %arg0[%c3], %c3 : memref<?xindex>, vector<3x4xindex>
532532
return %0 : vector<3x4xindex>
533533
}
534534

535535
// -----
536536

537+
func.func @test_vector.transfer_write(%arg0: memref<?xvector<2xindex>>) {
538+
%c3 = arith.constant 3 : index
539+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
540+
%0 = vector.transfer_read %arg0[%c3], %c3 : memref<?xvector<2xindex>>, vector<2x3x4xindex>
541+
return %0 : vector<2x3x4xindex>
542+
}
543+
544+
// -----
545+
537546
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
538547
%c3 = arith.constant 3 : index
539548
%cst = arith.constant 3.0 : f32
@@ -655,10 +664,18 @@ func.func @test_vector.transfer_write(%arg0: memref<?xf32>, %arg1: vector<7xf32>
655664

656665
// -----
657666

658-
func.func @test_vector.transfer_write(%vec_to_write: vector<3x4xindex>, %output_memref: memref<?xindex>) {
667+
func.func @test_vector.transfer_write(%arg0: memref<?xindex>, %arg1: vector<3x4xindex>) {
668+
%c3 = arith.constant 3 : index
669+
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
670+
vector.transfer_write %arg1, %arg0[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
671+
}
672+
673+
// -----
674+
675+
func.func @test_vector.transfer_write(%arg0: memref<?xvector<2xindex>>, %arg1: vector<2x3x4xindex>) {
659676
%c3 = arith.constant 3 : index
660677
// expected-error@+1 {{expected a custom permutation_map when rank(source) != rank(destination)}}
661-
vector.transfer_write %vec_to_write, %output_memref[%c3, %c3] : vector<3x4xindex>, memref<?xindex>
678+
vector.transfer_write %arg1, %arg0[%c3, %c3] : vector<2x3x4xindex>, memref<?xvector<2xindex>>
662679
}
663680

664681
// -----

0 commit comments

Comments
 (0)