@@ -2873,8 +2873,7 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2873
2873
static LogicalResult
2874
2874
verifyTransferOp (VectorTransferOpInterface op, ShapedType shapedType,
2875
2875
VectorType vectorType, VectorType maskType,
2876
- VectorType inferredMaskType, AffineMap permutationMap,
2877
- ArrayAttr inBounds) {
2876
+ AffineMap permutationMap, ArrayAttr inBounds) {
2878
2877
if (op->hasAttr (" masked" )) {
2879
2878
return op->emitOpError (" masked attribute has been removed. "
2880
2879
" Use in_bounds instead." );
@@ -2927,6 +2926,13 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2927
2926
if (permutationMap.getNumResults () != vectorType.getRank ())
2928
2927
return op->emitOpError (" requires a permutation_map with result dims of "
2929
2928
" the same rank as the vector type" );
2929
+
2930
+ VectorType expectedMaskType =
2931
+ vector::detail::transferMaskType (vectorType, permutationMap);
2932
+ if (maskType && expectedMaskType != maskType)
2933
+ return op->emitOpError (" expects mask type consistent with permutation "
2934
+ " map: " )
2935
+ << maskType;
2930
2936
}
2931
2937
2932
2938
if (permutationMap.getNumSymbols () != 0 )
@@ -2936,11 +2942,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2936
2942
return op->emitOpError (" requires a permutation_map with input dims of the "
2937
2943
" same rank as the source type" );
2938
2944
2939
- if (maskType && maskType != inferredMaskType)
2940
- return op->emitOpError (" inferred mask type (" )
2941
- << inferredMaskType << " ) and mask operand type (" << maskType
2942
- << " ) don't match" ;
2943
-
2944
2945
if (inBounds) {
2945
2946
if (permutationMap.getNumResults () != static_cast <int64_t >(inBounds.size ()))
2946
2947
return op->emitOpError (" expects the optional in_bounds attr of same rank "
@@ -2983,19 +2984,6 @@ void TransferReadOp::print(OpAsmPrinter &p) {
2983
2984
p << " : " << getShapedType () << " , " << getVectorType ();
2984
2985
}
2985
2986
2986
- // / Infers the mask type for a transfer read given its vector type and
2987
- // / permutation map. The mask in a transfer read operation applies to the
2988
- // / tensor/buffer reading part of it and its type should match the shape read
2989
- // / *before* any permutation or broadcasting.
2990
- static VectorType inferTransferReadMaskType (VectorType vecType,
2991
- AffineMap permMap) {
2992
- auto i1Type = IntegerType::get (permMap.getContext (), 1 );
2993
- AffineMap invPermMap = inversePermutation (compressUnusedDims (permMap));
2994
- assert (invPermMap && " Inversed permutation map couldn't be computed" );
2995
- SmallVector<int64_t , 8 > maskShape = invPermMap.compose (vecType.getShape ());
2996
- return VectorType::get (maskShape, i1Type);
2997
- }
2998
-
2999
2987
ParseResult TransferReadOp::parse (OpAsmParser &parser, OperationState &result) {
3000
2988
auto &builder = parser.getBuilder ();
3001
2989
SMLoc typesLoc;
@@ -3026,14 +3014,13 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3026
3014
VectorType vectorType = types[1 ].dyn_cast <VectorType>();
3027
3015
if (!vectorType)
3028
3016
return parser.emitError (typesLoc, " requires vector type" );
3029
- auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName ();
3030
- Attribute permMapAttr = result.attributes .get (permMapAttrName);
3031
- AffineMap permMap;
3032
- if (!permMapAttr) {
3033
- permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3034
- result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
3035
- } else {
3036
- permMap = permMapAttr.cast <AffineMapAttr>().getValue ();
3017
+ auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName ();
3018
+ Attribute mapAttr = result.attributes .get (permutationAttrName);
3019
+ if (!mapAttr) {
3020
+ auto permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3021
+ // Update `mapAttr` that is used later to determine mask type.
3022
+ mapAttr = AffineMapAttr::get (permMap);
3023
+ result.attributes .set (permutationAttrName, mapAttr);
3037
3024
}
3038
3025
if (parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
3039
3026
parser.resolveOperands (indexInfo, indexType, result.operands ) ||
@@ -3044,9 +3031,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3044
3031
if (shapedType.getElementType ().dyn_cast <VectorType>())
3045
3032
return parser.emitError (
3046
3033
maskInfo.location , " does not support masks with vector element type" );
3034
+ auto map = mapAttr.dyn_cast <AffineMapAttr>().getValue ();
3047
3035
// Instead of adding the mask type as an op type, compute it based on the
3048
3036
// vector type and the permutation map (to keep the type signature small).
3049
- auto maskType = inferTransferReadMaskType (vectorType, permMap );
3037
+ auto maskType = mlir::vector::detail::transferMaskType (vectorType, map );
3050
3038
if (parser.resolveOperand (maskInfo, maskType, result.operands ))
3051
3039
return failure ();
3052
3040
}
@@ -3064,17 +3052,13 @@ LogicalResult TransferReadOp::verify() {
3064
3052
VectorType maskType = getMaskType ();
3065
3053
auto paddingType = getPadding ().getType ();
3066
3054
auto permutationMap = getPermutationMap ();
3067
- VectorType inferredMaskType =
3068
- maskType ? inferTransferReadMaskType (vectorType, permutationMap)
3069
- : VectorType ();
3070
3055
auto sourceElementType = shapedType.getElementType ();
3071
3056
3072
3057
if (static_cast <int64_t >(getIndices ().size ()) != shapedType.getRank ())
3073
3058
return emitOpError (" requires " ) << shapedType.getRank () << " indices" ;
3074
3059
3075
3060
if (failed (verifyTransferOp (cast<VectorTransferOpInterface>(getOperation ()),
3076
- shapedType, vectorType, maskType,
3077
- inferredMaskType, permutationMap,
3061
+ shapedType, vectorType, maskType, permutationMap,
3078
3062
getInBounds () ? *getInBounds () : ArrayAttr ())))
3079
3063
return failure ();
3080
3064
@@ -3438,18 +3422,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3438
3422
build (builder, result, vector, dest, indices, permutationMap, inBounds);
3439
3423
}
3440
3424
3441
- // / Infers the mask type for a transfer write given its vector type and
3442
- // / permutation map. The mask in a transfer read operation applies to the
3443
- // / tensor/buffer writing part of it and its type should match the shape written
3444
- // / *after* any permutation.
3445
- static VectorType inferTransferWriteMaskType (VectorType vecType,
3446
- AffineMap permMap) {
3447
- auto i1Type = IntegerType::get (permMap.getContext (), 1 );
3448
- SmallVector<int64_t , 8 > maskShape =
3449
- compressUnusedDims (permMap).compose (vecType.getShape ());
3450
- return VectorType::get (maskShape, i1Type);
3451
- }
3452
-
3453
3425
ParseResult TransferWriteOp::parse (OpAsmParser &parser,
3454
3426
OperationState &result) {
3455
3427
auto &builder = parser.getBuilder ();
@@ -3477,14 +3449,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3477
3449
ShapedType shapedType = types[1 ].dyn_cast <ShapedType>();
3478
3450
if (!shapedType || !shapedType.isa <MemRefType, RankedTensorType>())
3479
3451
return parser.emitError (typesLoc, " requires memref or ranked tensor type" );
3480
- auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName ();
3481
- auto permMapAttr = result.attributes .get (permMapAttrName);
3482
- AffineMap permMap;
3483
- if (!permMapAttr) {
3484
- permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3485
- result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
3486
- } else {
3487
- permMap = permMapAttr.cast <AffineMapAttr>().getValue ();
3452
+ auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName ();
3453
+ auto attr = result.attributes .get (permutationAttrName);
3454
+ if (!attr) {
3455
+ auto permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3456
+ result.attributes .set (permutationAttrName, AffineMapAttr::get (permMap));
3488
3457
}
3489
3458
if (parser.resolveOperand (vectorInfo, vectorType, result.operands ) ||
3490
3459
parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
@@ -3494,7 +3463,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3494
3463
if (shapedType.getElementType ().dyn_cast <VectorType>())
3495
3464
return parser.emitError (
3496
3465
maskInfo.location , " does not support masks with vector element type" );
3497
- auto maskType = inferTransferWriteMaskType (vectorType, permMap );
3466
+ auto maskType = VectorType::get (vectorType. getShape (), builder. getI1Type () );
3498
3467
if (parser.resolveOperand (maskInfo, maskType, result.operands ))
3499
3468
return failure ();
3500
3469
}
@@ -3520,9 +3489,6 @@ LogicalResult TransferWriteOp::verify() {
3520
3489
VectorType vectorType = getVectorType ();
3521
3490
VectorType maskType = getMaskType ();
3522
3491
auto permutationMap = getPermutationMap ();
3523
- VectorType inferredMaskType =
3524
- maskType ? inferTransferWriteMaskType (vectorType, permutationMap)
3525
- : VectorType ();
3526
3492
3527
3493
if (llvm::size (getIndices ()) != shapedType.getRank ())
3528
3494
return emitOpError (" requires " ) << shapedType.getRank () << " indices" ;
@@ -3533,8 +3499,7 @@ LogicalResult TransferWriteOp::verify() {
3533
3499
return emitOpError (" should not have broadcast dimensions" );
3534
3500
3535
3501
if (failed (verifyTransferOp (cast<VectorTransferOpInterface>(getOperation ()),
3536
- shapedType, vectorType, maskType,
3537
- inferredMaskType, permutationMap,
3502
+ shapedType, vectorType, maskType, permutationMap,
3538
3503
getInBounds () ? *getInBounds () : ArrayAttr ())))
3539
3504
return failure ();
3540
3505
0 commit comments