@@ -3128,7 +3128,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3128
3128
static LogicalResult
3129
3129
verifyTransferOp (VectorTransferOpInterface op, ShapedType shapedType,
3130
3130
VectorType vectorType, VectorType maskType,
3131
- AffineMap permutationMap, ArrayAttr inBounds) {
3131
+ VectorType inferredMaskType, AffineMap permutationMap,
3132
+ ArrayAttr inBounds) {
3132
3133
if (op->hasAttr (" masked" )) {
3133
3134
return op->emitOpError (" masked attribute has been removed. "
3134
3135
" Use in_bounds instead." );
@@ -3181,13 +3182,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3181
3182
if (permutationMap.getNumResults () != vectorType.getRank ())
3182
3183
return op->emitOpError (" requires a permutation_map with result dims of "
3183
3184
" the same rank as the vector type" );
3184
-
3185
- VectorType expectedMaskType =
3186
- vector::detail::transferMaskType (vectorType, permutationMap);
3187
- if (maskType && expectedMaskType != maskType)
3188
- return op->emitOpError (" expects mask type consistent with permutation "
3189
- " map: " )
3190
- << maskType;
3191
3185
}
3192
3186
3193
3187
if (permutationMap.getNumSymbols () != 0 )
@@ -3197,6 +3191,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3197
3191
return op->emitOpError (" requires a permutation_map with input dims of the "
3198
3192
" same rank as the source type" );
3199
3193
3194
+ if (maskType && maskType != inferredMaskType)
3195
+ return op->emitOpError (" inferred mask type (" )
3196
+ << inferredMaskType << " ) and mask operand type (" << maskType
3197
+ << " ) don't match" ;
3198
+
3200
3199
if (inBounds) {
3201
3200
if (permutationMap.getNumResults () != static_cast <int64_t >(inBounds.size ()))
3202
3201
return op->emitOpError (" expects the optional in_bounds attr of same rank "
@@ -3239,6 +3238,19 @@ void TransferReadOp::print(OpAsmPrinter &p) {
3239
3238
p << " : " << getShapedType () << " , " << getVectorType ();
3240
3239
}
3241
3240
3241
+ // / Infers the mask type for a transfer read given its vector type and
3242
+ // / permutation map. The mask in a transfer read operation applies to the
3243
+ // / tensor/buffer reading part of it and its type should match the shape read
3244
+ // / *before* any permutation or broadcasting.
3245
+ static VectorType inferTransferReadMaskType (VectorType vecType,
3246
+ AffineMap permMap) {
3247
+ auto i1Type = IntegerType::get (permMap.getContext (), 1 );
3248
+ AffineMap invPermMap = inversePermutation (compressUnusedDims (permMap));
3249
+ assert (invPermMap && " Inversed permutation map couldn't be computed" );
3250
+ SmallVector<int64_t , 8 > maskShape = invPermMap.compose (vecType.getShape ());
3251
+ return VectorType::get (maskShape, i1Type);
3252
+ }
3253
+
3242
3254
ParseResult TransferReadOp::parse (OpAsmParser &parser, OperationState &result) {
3243
3255
auto &builder = parser.getBuilder ();
3244
3256
SMLoc typesLoc;
@@ -3269,13 +3281,14 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3269
3281
VectorType vectorType = types[1 ].dyn_cast <VectorType>();
3270
3282
if (!vectorType)
3271
3283
return parser.emitError (typesLoc, " requires vector type" );
3272
- auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName ();
3273
- Attribute mapAttr = result.attributes .get (permutationAttrName);
3274
- if (!mapAttr) {
3275
- auto permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3276
- // Update `mapAttr` that is used later to determine mask type.
3277
- mapAttr = AffineMapAttr::get (permMap);
3278
- result.attributes .set (permutationAttrName, mapAttr);
3284
+ auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName ();
3285
+ Attribute permMapAttr = result.attributes .get (permMapAttrName);
3286
+ AffineMap permMap;
3287
+ if (!permMapAttr) {
3288
+ permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3289
+ result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
3290
+ } else {
3291
+ permMap = permMapAttr.cast <AffineMapAttr>().getValue ();
3279
3292
}
3280
3293
if (parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
3281
3294
parser.resolveOperands (indexInfo, indexType, result.operands ) ||
@@ -3286,10 +3299,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3286
3299
if (shapedType.getElementType ().dyn_cast <VectorType>())
3287
3300
return parser.emitError (
3288
3301
maskInfo.location , " does not support masks with vector element type" );
3289
- auto map = mapAttr.dyn_cast <AffineMapAttr>().getValue ();
3290
3302
// Instead of adding the mask type as an op type, compute it based on the
3291
3303
// vector type and the permutation map (to keep the type signature small).
3292
- auto maskType = mlir::vector::detail::transferMaskType (vectorType, map );
3304
+ auto maskType = inferTransferReadMaskType (vectorType, permMap );
3293
3305
if (parser.resolveOperand (maskInfo, maskType, result.operands ))
3294
3306
return failure ();
3295
3307
}
@@ -3307,13 +3319,17 @@ LogicalResult TransferReadOp::verify() {
3307
3319
VectorType maskType = getMaskType ();
3308
3320
auto paddingType = getPadding ().getType ();
3309
3321
auto permutationMap = getPermutationMap ();
3322
+ VectorType inferredMaskType =
3323
+ maskType ? inferTransferReadMaskType (vectorType, permutationMap)
3324
+ : VectorType ();
3310
3325
auto sourceElementType = shapedType.getElementType ();
3311
3326
3312
3327
if (static_cast <int64_t >(getIndices ().size ()) != shapedType.getRank ())
3313
3328
return emitOpError (" requires " ) << shapedType.getRank () << " indices" ;
3314
3329
3315
3330
if (failed (verifyTransferOp (cast<VectorTransferOpInterface>(getOperation ()),
3316
- shapedType, vectorType, maskType, permutationMap,
3331
+ shapedType, vectorType, maskType,
3332
+ inferredMaskType, permutationMap,
3317
3333
getInBounds () ? *getInBounds () : ArrayAttr ())))
3318
3334
return failure ();
3319
3335
@@ -3677,6 +3693,18 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3677
3693
build (builder, result, vector, dest, indices, permutationMap, inBounds);
3678
3694
}
3679
3695
3696
+ // / Infers the mask type for a transfer write given its vector type and
3697
+ // / permutation map. The mask in a transfer read operation applies to the
3698
+ // / tensor/buffer writing part of it and its type should match the shape written
3699
+ // / *after* any permutation.
3700
+ static VectorType inferTransferWriteMaskType (VectorType vecType,
3701
+ AffineMap permMap) {
3702
+ auto i1Type = IntegerType::get (permMap.getContext (), 1 );
3703
+ SmallVector<int64_t , 8 > maskShape =
3704
+ compressUnusedDims (permMap).compose (vecType.getShape ());
3705
+ return VectorType::get (maskShape, i1Type);
3706
+ }
3707
+
3680
3708
ParseResult TransferWriteOp::parse (OpAsmParser &parser,
3681
3709
OperationState &result) {
3682
3710
auto &builder = parser.getBuilder ();
@@ -3704,11 +3732,14 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3704
3732
ShapedType shapedType = types[1 ].dyn_cast <ShapedType>();
3705
3733
if (!shapedType || !shapedType.isa <MemRefType, RankedTensorType>())
3706
3734
return parser.emitError (typesLoc, " requires memref or ranked tensor type" );
3707
- auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName ();
3708
- auto attr = result.attributes .get (permutationAttrName);
3709
- if (!attr) {
3710
- auto permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3711
- result.attributes .set (permutationAttrName, AffineMapAttr::get (permMap));
3735
+ auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName ();
3736
+ auto permMapAttr = result.attributes .get (permMapAttrName);
3737
+ AffineMap permMap;
3738
+ if (!permMapAttr) {
3739
+ permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3740
+ result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
3741
+ } else {
3742
+ permMap = permMapAttr.cast <AffineMapAttr>().getValue ();
3712
3743
}
3713
3744
if (parser.resolveOperand (vectorInfo, vectorType, result.operands ) ||
3714
3745
parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
@@ -3718,7 +3749,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3718
3749
if (shapedType.getElementType ().dyn_cast <VectorType>())
3719
3750
return parser.emitError (
3720
3751
maskInfo.location , " does not support masks with vector element type" );
3721
- auto maskType = VectorType::get (vectorType. getShape (), builder. getI1Type () );
3752
+ auto maskType = inferTransferWriteMaskType (vectorType, permMap );
3722
3753
if (parser.resolveOperand (maskInfo, maskType, result.operands ))
3723
3754
return failure ();
3724
3755
}
@@ -3744,6 +3775,9 @@ LogicalResult TransferWriteOp::verify() {
3744
3775
VectorType vectorType = getVectorType ();
3745
3776
VectorType maskType = getMaskType ();
3746
3777
auto permutationMap = getPermutationMap ();
3778
+ VectorType inferredMaskType =
3779
+ maskType ? inferTransferWriteMaskType (vectorType, permutationMap)
3780
+ : VectorType ();
3747
3781
3748
3782
if (llvm::size (getIndices ()) != shapedType.getRank ())
3749
3783
return emitOpError (" requires " ) << shapedType.getRank () << " indices" ;
@@ -3754,7 +3788,8 @@ LogicalResult TransferWriteOp::verify() {
3754
3788
return emitOpError (" should not have broadcast dimensions" );
3755
3789
3756
3790
if (failed (verifyTransferOp (cast<VectorTransferOpInterface>(getOperation ()),
3757
- shapedType, vectorType, maskType, permutationMap,
3791
+ shapedType, vectorType, maskType,
3792
+ inferredMaskType, permutationMap,
3758
3793
getInBounds () ? *getInBounds () : ArrayAttr ())))
3759
3794
return failure ();
3760
3795
0 commit comments