Skip to content

Commit eb7e299

Browse files
committed
Reland "[mlir][Vector] Re-define masking semantics in vector.transfer ops""
This relands commit 847b5f8. Differential Revision: https://reviews.llvm.org/D138079
1 parent c0321ed commit eb7e299

File tree

10 files changed

+103
-103
lines changed

10 files changed

+103
-103
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,9 +1207,11 @@ def Vector_TransferReadOp :
12071207
provided to specify a fallback value in the case of out-of-bounds accesses
12081208
and/or masking.
12091209

1210-
An optional SSA value `mask` of the same shape as the vector type may be
1211-
specified to mask out elements. Such elements will be replaces with
1212-
`padding`. Elements whose corresponding mask element is `0` are masked out.
1210+
An optional SSA value `mask` may be specified to mask out elements read from
1211+
the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
1212+
matches how elements are read from the MemRef/Tensor, *before* any
1213+
permutation or broadcasting. Elements whose corresponding mask element is
1214+
`0` are masked out and replaced with `padding`.
12131215

12141216
An optional boolean array attribute `in_bounds` specifies for every vector
12151217
dimension if the transfer is guaranteed to be within the source bounds.
@@ -1419,6 +1421,12 @@ def Vector_TransferWriteOp :
14191421

14201422
The size of the slice is specified by the size of the vector.
14211423

1424+
An optional SSA value `mask` may be specified to mask out elements written
1425+
to the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
1426+
matches how elements are written into the MemRef/Tensor, *after* applying
1427+
any permutation. Elements whose corresponding mask element is `0` are
1428+
masked out.
1429+
14221430
An optional SSA value `mask` of the same shape as the vector type may be
14231431
specified to mask out elements. Elements whose corresponding mask element
14241432
is `0` are masked out.

mlir/include/mlir/Interfaces/VectorInterfaces.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,6 @@
1717
#include "mlir/IR/BuiltinTypes.h"
1818
#include "mlir/IR/OpDefinition.h"
1919

20-
namespace mlir {
21-
namespace vector {
22-
namespace detail {
23-
24-
/// Given the vector type and the permutation map of a vector transfer op,
25-
/// compute the expected mask type.
26-
VectorType transferMaskType(VectorType vecType, AffineMap map);
27-
28-
} // namespace detail
29-
} // namespace vector
30-
} // namespace mlir
31-
3220
/// Include the generated interface declarations.
3321
#include "mlir/Interfaces/VectorInterfaces.h.inc"
3422

mlir/include/mlir/Interfaces/VectorInterfaces.td

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,25 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
169169
}]
170170
>,
171171
InterfaceMethod<
172-
/*desc=*/"Return the mask type if the op has a mask.",
172+
/*desc=*/"Return the mask operand if the op has a mask. Otherwise, "
173+
"return a empty value.",
174+
/*retTy=*/"Value",
175+
/*methodName=*/"getMask",
176+
/*args=*/(ins),
177+
/*methodBody=*/"",
178+
/*defaultImplementation=*/[{
179+
return $_op.getMask();
180+
}]
181+
>,
182+
InterfaceMethod<
183+
/*desc=*/"Return the mask type if the op has a mask. Otherwise, return "
184+
"an empty VectorType.",
173185
/*retTy=*/"::mlir::VectorType",
174186
/*methodName=*/"getMaskType",
175187
/*args=*/(ins),
176188
/*methodBody=*/"",
177189
/*defaultImplementation=*/[{
178-
return $_op.getMask()
179-
? ::mlir::vector::detail::transferMaskType(
180-
$_op.getVectorType(), $_op.getPermutationMap())
181-
: ::mlir::VectorType();
190+
return $_op.getMask() ? $_op.getMask().getType() : ::mlir::VectorType();
182191
}]
183192
>,
184193
InterfaceMethod<

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3128,7 +3128,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
31283128
static LogicalResult
31293129
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
31303130
VectorType vectorType, VectorType maskType,
3131-
AffineMap permutationMap, ArrayAttr inBounds) {
3131+
VectorType inferredMaskType, AffineMap permutationMap,
3132+
ArrayAttr inBounds) {
31323133
if (op->hasAttr("masked")) {
31333134
return op->emitOpError("masked attribute has been removed. "
31343135
"Use in_bounds instead.");
@@ -3181,13 +3182,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
31813182
if (permutationMap.getNumResults() != vectorType.getRank())
31823183
return op->emitOpError("requires a permutation_map with result dims of "
31833184
"the same rank as the vector type");
3184-
3185-
VectorType expectedMaskType =
3186-
vector::detail::transferMaskType(vectorType, permutationMap);
3187-
if (maskType && expectedMaskType != maskType)
3188-
return op->emitOpError("expects mask type consistent with permutation "
3189-
"map: ")
3190-
<< maskType;
31913185
}
31923186

31933187
if (permutationMap.getNumSymbols() != 0)
@@ -3197,6 +3191,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
31973191
return op->emitOpError("requires a permutation_map with input dims of the "
31983192
"same rank as the source type");
31993193

3194+
if (maskType && maskType != inferredMaskType)
3195+
return op->emitOpError("inferred mask type (")
3196+
<< inferredMaskType << ") and mask operand type (" << maskType
3197+
<< ") don't match";
3198+
32003199
if (inBounds) {
32013200
if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
32023201
return op->emitOpError("expects the optional in_bounds attr of same rank "
@@ -3239,6 +3238,19 @@ void TransferReadOp::print(OpAsmPrinter &p) {
32393238
p << " : " << getShapedType() << ", " << getVectorType();
32403239
}
32413240

3241+
/// Infers the mask type for a transfer read given its vector type and
3242+
/// permutation map. The mask in a transfer read operation applies to the
3243+
/// tensor/buffer reading part of it and its type should match the shape read
3244+
/// *before* any permutation or broadcasting.
3245+
static VectorType inferTransferReadMaskType(VectorType vecType,
3246+
AffineMap permMap) {
3247+
auto i1Type = IntegerType::get(permMap.getContext(), 1);
3248+
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3249+
assert(invPermMap && "Inversed permutation map couldn't be computed");
3250+
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3251+
return VectorType::get(maskShape, i1Type);
3252+
}
3253+
32423254
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
32433255
auto &builder = parser.getBuilder();
32443256
SMLoc typesLoc;
@@ -3269,13 +3281,14 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
32693281
VectorType vectorType = types[1].dyn_cast<VectorType>();
32703282
if (!vectorType)
32713283
return parser.emitError(typesLoc, "requires vector type");
3272-
auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
3273-
Attribute mapAttr = result.attributes.get(permutationAttrName);
3274-
if (!mapAttr) {
3275-
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3276-
// Update `mapAttr` that is used later to determine mask type.
3277-
mapAttr = AffineMapAttr::get(permMap);
3278-
result.attributes.set(permutationAttrName, mapAttr);
3284+
auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
3285+
Attribute permMapAttr = result.attributes.get(permMapAttrName);
3286+
AffineMap permMap;
3287+
if (!permMapAttr) {
3288+
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3289+
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3290+
} else {
3291+
permMap = permMapAttr.cast<AffineMapAttr>().getValue();
32793292
}
32803293
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
32813294
parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -3286,10 +3299,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
32863299
if (shapedType.getElementType().dyn_cast<VectorType>())
32873300
return parser.emitError(
32883301
maskInfo.location, "does not support masks with vector element type");
3289-
auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
32903302
// Instead of adding the mask type as an op type, compute it based on the
32913303
// vector type and the permutation map (to keep the type signature small).
3292-
auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
3304+
auto maskType = inferTransferReadMaskType(vectorType, permMap);
32933305
if (parser.resolveOperand(maskInfo, maskType, result.operands))
32943306
return failure();
32953307
}
@@ -3307,13 +3319,17 @@ LogicalResult TransferReadOp::verify() {
33073319
VectorType maskType = getMaskType();
33083320
auto paddingType = getPadding().getType();
33093321
auto permutationMap = getPermutationMap();
3322+
VectorType inferredMaskType =
3323+
maskType ? inferTransferReadMaskType(vectorType, permutationMap)
3324+
: VectorType();
33103325
auto sourceElementType = shapedType.getElementType();
33113326

33123327
if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
33133328
return emitOpError("requires ") << shapedType.getRank() << " indices";
33143329

33153330
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3316-
shapedType, vectorType, maskType, permutationMap,
3331+
shapedType, vectorType, maskType,
3332+
inferredMaskType, permutationMap,
33173333
getInBounds() ? *getInBounds() : ArrayAttr())))
33183334
return failure();
33193335

@@ -3677,6 +3693,18 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
36773693
build(builder, result, vector, dest, indices, permutationMap, inBounds);
36783694
}
36793695

3696+
/// Infers the mask type for a transfer write given its vector type and
3697+
/// permutation map. The mask in a transfer read operation applies to the
3698+
/// tensor/buffer writing part of it and its type should match the shape written
3699+
/// *after* any permutation.
3700+
static VectorType inferTransferWriteMaskType(VectorType vecType,
3701+
AffineMap permMap) {
3702+
auto i1Type = IntegerType::get(permMap.getContext(), 1);
3703+
SmallVector<int64_t, 8> maskShape =
3704+
compressUnusedDims(permMap).compose(vecType.getShape());
3705+
return VectorType::get(maskShape, i1Type);
3706+
}
3707+
36803708
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
36813709
OperationState &result) {
36823710
auto &builder = parser.getBuilder();
@@ -3704,11 +3732,14 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
37043732
ShapedType shapedType = types[1].dyn_cast<ShapedType>();
37053733
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
37063734
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3707-
auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3708-
auto attr = result.attributes.get(permutationAttrName);
3709-
if (!attr) {
3710-
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3711-
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
3735+
auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3736+
auto permMapAttr = result.attributes.get(permMapAttrName);
3737+
AffineMap permMap;
3738+
if (!permMapAttr) {
3739+
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3740+
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3741+
} else {
3742+
permMap = permMapAttr.cast<AffineMapAttr>().getValue();
37123743
}
37133744
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
37143745
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
@@ -3718,7 +3749,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
37183749
if (shapedType.getElementType().dyn_cast<VectorType>())
37193750
return parser.emitError(
37203751
maskInfo.location, "does not support masks with vector element type");
3721-
auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
3752+
auto maskType = inferTransferWriteMaskType(vectorType, permMap);
37223753
if (parser.resolveOperand(maskInfo, maskType, result.operands))
37233754
return failure();
37243755
}
@@ -3744,6 +3775,9 @@ LogicalResult TransferWriteOp::verify() {
37443775
VectorType vectorType = getVectorType();
37453776
VectorType maskType = getMaskType();
37463777
auto permutationMap = getPermutationMap();
3778+
VectorType inferredMaskType =
3779+
maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
3780+
: VectorType();
37473781

37483782
if (llvm::size(getIndices()) != shapedType.getRank())
37493783
return emitOpError("requires ") << shapedType.getRank() << " indices";
@@ -3754,7 +3788,8 @@ LogicalResult TransferWriteOp::verify() {
37543788
return emitOpError("should not have broadcast dimensions");
37553789

37563790
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3757-
shapedType, vectorType, maskType, permutationMap,
3791+
shapedType, vectorType, maskType,
3792+
inferredMaskType, permutationMap,
37583793
getInBounds() ? *getInBounds() : ArrayAttr())))
37593794
return failure();
37603795

mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -83,26 +83,6 @@ struct TransferReadPermutationLowering
8383
newVectorShape[pos.value()] = originalShape[pos.index()];
8484
}
8585

86-
// Transpose mask operand.
87-
Value newMask;
88-
if (op.getMask()) {
89-
// Remove unused dims from the permutation map. E.g.:
90-
// E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
91-
// comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
92-
auto comp = compressUnusedDims(map);
93-
// Get positions of remaining result dims.
94-
// E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
95-
// maskTransposeIndices = [ 2, 1, 0]
96-
SmallVector<int64_t> maskTransposeIndices;
97-
for (unsigned i = 0; i < comp.getNumResults(); ++i) {
98-
if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
99-
maskTransposeIndices.push_back(expr.getPosition());
100-
}
101-
102-
newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
103-
maskTransposeIndices);
104-
}
105-
10686
// Transpose in_bounds attribute.
10787
ArrayAttr newInBoundsAttr =
10888
op.getInBounds() ? transposeInBoundsAttr(
@@ -114,7 +94,8 @@ struct TransferReadPermutationLowering
11494
VectorType::get(newVectorShape, op.getVectorType().getElementType());
11595
Value newRead = rewriter.create<vector::TransferReadOp>(
11696
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
117-
AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
97+
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
98+
newInBoundsAttr);
11899

119100
// Transpose result of transfer_read.
120101
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
@@ -168,11 +149,6 @@ struct TransferWritePermutationLowering
168149
return expr.dyn_cast<AffineDimExpr>().getPosition();
169150
});
170151

171-
// Transpose mask operand.
172-
Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
173-
op.getLoc(), op.getMask(), indices)
174-
: Value();
175-
176152
// Transpose in_bounds attribute.
177153
ArrayAttr newInBoundsAttr =
178154
op.getInBounds() ? transposeInBoundsAttr(
@@ -186,7 +162,7 @@ struct TransferWritePermutationLowering
186162
map.getNumDims(), map.getNumResults(), rewriter.getContext());
187163
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
188164
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
189-
newMask, newInBoundsAttr);
165+
op.getMask(), newInBoundsAttr);
190166

191167
return success();
192168
}

mlir/lib/Interfaces/VectorInterfaces.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,6 @@
1010

1111
using namespace mlir;
1212

13-
VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
14-
AffineMap map) {
15-
auto i1Type = IntegerType::get(map.getContext(), 1);
16-
SmallVector<int64_t, 8> shape;
17-
for (int64_t i = 0; i < vecType.getRank(); ++i) {
18-
// Only result dims have a corresponding dim in the mask.
19-
if (map.getResult(i).template isa<AffineDimExpr>()) {
20-
shape.push_back(vecType.getDimSize(i));
21-
}
22-
}
23-
return VectorType::get(shape, i1Type);
24-
}
25-
2613
//===----------------------------------------------------------------------===//
2714
// VectorUnroll Interfaces
2815
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66
// CHECK-LABEL: func @transfer_read_2d_mask_transposed(
77
// CHECK-DAG: %[[PADDING:.*]] = arith.constant dense<-4.200000e+01> : vector<9xf32>
8-
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<9x4xi1>
8+
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<4x9xi1>
99
// CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
10-
// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
11-
// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
10+
// CHECK: memref.store %[[MASK]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
1211
// CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
1312
// CHECK: scf.for {{.*}} {
1413
// CHECK: scf.if {{.*}} {
@@ -25,11 +24,10 @@
2524
func.func @transfer_read_2d_mask_transposed(
2625
%A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
2726
%fm42 = arith.constant -42.0: f32
28-
%mask = arith.constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
29-
[1, 1, 1, 1], [0, 1, 1, 0],
30-
[1, 1, 1, 1], [1, 1, 1, 1],
31-
[1, 1, 1, 1], [0, 0, 0, 0],
32-
[1, 1, 1, 1]]> : vector<9x4xi1>
27+
%mask = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
28+
[0, 0, 1, 1, 1, 1, 1, 0, 1],
29+
[1, 1, 1, 1, 1, 1, 1, 0, 1],
30+
[0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
3331
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
3432
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
3533
memref<?x?xf32>, vector<9x4xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
4949
%v0 = vector.splat %c0 : vector<4x3xi32>
5050
%vi0 = vector.splat %i0 : vector<4x3xindex>
5151
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
52-
%m2 = vector.splat %i1 : vector<5x4xi1>
52+
%m2 = vector.splat %i1 : vector<4x5xi1>
5353
//
5454
// CHECK: vector.transfer_read
5555
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>

0 commit comments

Comments
 (0)