@@ -151,29 +151,32 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
151
151
return false ;
152
152
}
153
153
154
- // / Returns the number of dimensions of the `shapedType` that participate in the
155
- // / vector transfer, effectively the rank of the vector dimensions within the
156
- // / `shapedType`. This is calculated by taking the rank of the `vectorType`
157
- // / being transferred and subtracting the rank of the `shapedType`'s element
158
- // / type if it's also a vector.
154
+ // / Returns the effective rank of the vector to read/write for Xfer Ops
159
155
// /
160
- // / This is used to determine the number of minor dimensions for identity maps
161
- // / in vector transfers.
156
+ // / When the element type of the shaped type is _a scalar_, this will simply
157
+ // / return the rank of the vector ( the result for xfer_read or the value to
158
+ // / store for xfer_write).
162
159
// /
163
- // / For example, given a transfer operation involving `shapedType` and
164
- // / `vectorType`:
160
+ // / When the element type of the base shaped type is _a vector_, returns the
161
+ // / difference between the original vector type and the element type of the
162
+ // / shaped type.
165
163
// /
164
+ // / EXAMPLE 1 (element type is _a scalar_):
166
165
// / - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167
166
// / - shapedType.getElementType() = f32 (rank 0)
168
167
// / - vectorType.getRank() = 2
169
168
// / - Result = 2 - 0 = 2
170
169
// /
170
+ // / EXAMPLE 2 (element type is _a vector_):
171
171
// / - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
172
172
// / - shapedType.getElementType() = vector<20xf32> (rank 1)
173
173
// / - vectorType.getRank() = 1
174
174
// / - Result = 1 - 1 = 0
175
- static unsigned getRealVectorRank (ShapedType shapedType,
176
- VectorType vectorType) {
175
+ // /
176
+ // / This is used to determine the number of minor dimensions for identity maps
177
+ // / in vector transfer Ops.
178
+ static unsigned getEffectiveVectorRankForXferOp (ShapedType shapedType,
179
+ VectorType vectorType) {
177
180
unsigned elementVectorRank = 0 ;
178
181
VectorType elementVectorType =
179
182
llvm::dyn_cast<VectorType>(shapedType.getElementType ());
@@ -192,7 +195,8 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
192
195
/* numDims=*/ 0 , /* numSymbols=*/ 0 ,
193
196
getAffineConstantExpr (0 , shapedType.getContext ()));
194
197
return AffineMap::getMinorIdentityMap (
195
- shapedType.getRank (), getRealVectorRank (shapedType, vectorType),
198
+ shapedType.getRank (),
199
+ getEffectiveVectorRankForXferOp (shapedType, vectorType),
196
200
shapedType.getContext ());
197
201
}
198
202
@@ -4261,7 +4265,8 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4261
4265
Attribute permMapAttr = result.attributes .get (permMapAttrName);
4262
4266
AffineMap permMap;
4263
4267
if (!permMapAttr) {
4264
- if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4268
+ if (shapedType.getRank () <
4269
+ getEffectiveVectorRankForXferOp (shapedType, vectorType))
4265
4270
return parser.emitError (typesLoc,
4266
4271
" expected a custom permutation_map when "
4267
4272
" rank(source) != rank(destination)" );
@@ -4680,7 +4685,8 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
4680
4685
auto permMapAttr = result.attributes .get (permMapAttrName);
4681
4686
AffineMap permMap;
4682
4687
if (!permMapAttr) {
4683
- if (shapedType.getRank () < getRealVectorRank (shapedType, vectorType))
4688
+ if (shapedType.getRank () <
4689
+ getEffectiveVectorRankForXferOp (shapedType, vectorType))
4684
4690
return parser.emitError (typesLoc,
4685
4691
" expected a custom permutation_map when "
4686
4692
" rank(source) != rank(destination)" );
0 commit comments