Skip to content

Commit 847b5f8

Browse files
committed
Revert "[mlir][Vector] Re-define masking semantics in vector.transfer ops"
This reverts commit 6c59c5c.
1 parent f51c915 commit 847b5f8

File tree

9 files changed

+98
-99
lines changed

9 files changed

+98
-99
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,11 +1203,9 @@ def Vector_TransferReadOp :
12031203
provided to specify a fallback value in the case of out-of-bounds accesses
12041204
and/or masking.
12051205

1206-
An optional SSA value `mask` may be specified to mask out elements read from
1207-
the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
1208-
matches how elements are read from the MemRef/Tensor, *before* any
1209-
permutation or broadcasting. Elements whose corresponding mask element is
1210-
`0` are masked out and replaced with `padding`.
1206+
An optional SSA value `mask` of the same shape as the vector type may be
1207+
specified to mask out elements. Such elements will be replaces with
1208+
`padding`. Elements whose corresponding mask element is `0` are masked out.
12111209

12121210
An optional boolean array attribute `in_bounds` specifies for every vector
12131211
dimension if the transfer is guaranteed to be within the source bounds.
@@ -1417,12 +1415,6 @@ def Vector_TransferWriteOp :
14171415

14181416
The size of the slice is specified by the size of the vector.
14191417

1420-
An optional SSA value `mask` may be specified to mask out elements written
1421-
to the MemRef/Tensor. The `mask` type is an `i1` vector with a shape that
1422-
matches how elements are written into the MemRef/Tensor, *after* applying
1423-
any permutation. Elements whose corresponding mask element is `0` are
1424-
masked out.
1425-
14261418
An optional SSA value `mask` of the same shape as the vector type may be
14271419
specified to mask out elements. Elements whose corresponding mask element
14281420
is `0` are masked out.

mlir/include/mlir/Interfaces/VectorInterfaces.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@
1717
#include "mlir/IR/BuiltinTypes.h"
1818
#include "mlir/IR/OpDefinition.h"
1919

20+
namespace mlir {
21+
namespace vector {
22+
namespace detail {
23+
24+
/// Given the vector type and the permutation map of a vector transfer op,
25+
/// compute the expected mask type.
26+
VectorType transferMaskType(VectorType vecType, AffineMap map);
27+
28+
} // namespace detail
29+
} // namespace vector
30+
} // namespace mlir
31+
2032
/// Include the generated interface declarations.
2133
#include "mlir/Interfaces/VectorInterfaces.h.inc"
2234

mlir/include/mlir/Interfaces/VectorInterfaces.td

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,25 +169,16 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
169169
}]
170170
>,
171171
InterfaceMethod<
172-
/*desc=*/"Return the mask operand if the op has a mask. Otherwise, "
173-
"return a empty value.",
174-
/*retTy=*/"Value",
175-
/*methodName=*/"getMask",
176-
/*args=*/(ins),
177-
/*methodBody=*/"",
178-
/*defaultImplementation=*/[{
179-
return $_op.getMask();
180-
}]
181-
>,
182-
InterfaceMethod<
183-
/*desc=*/"Return the mask type if the op has a mask. Otherwise, return "
184-
"an empty VectorType.",
172+
/*desc=*/"Return the mask type if the op has a mask.",
185173
/*retTy=*/"::mlir::VectorType",
186174
/*methodName=*/"getMaskType",
187175
/*args=*/(ins),
188176
/*methodBody=*/"",
189177
/*defaultImplementation=*/[{
190-
return $_op.getMask() ? $_op.getMask().getType() : ::mlir::VectorType();
178+
return $_op.getMask()
179+
? ::mlir::vector::detail::transferMaskType(
180+
$_op.getVectorType(), $_op.getPermutationMap())
181+
: ::mlir::VectorType();
191182
}]
192183
>,
193184
InterfaceMethod<

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 25 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2873,8 +2873,7 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
28732873
static LogicalResult
28742874
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
28752875
VectorType vectorType, VectorType maskType,
2876-
VectorType inferredMaskType, AffineMap permutationMap,
2877-
ArrayAttr inBounds) {
2876+
AffineMap permutationMap, ArrayAttr inBounds) {
28782877
if (op->hasAttr("masked")) {
28792878
return op->emitOpError("masked attribute has been removed. "
28802879
"Use in_bounds instead.");
@@ -2927,6 +2926,13 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
29272926
if (permutationMap.getNumResults() != vectorType.getRank())
29282927
return op->emitOpError("requires a permutation_map with result dims of "
29292928
"the same rank as the vector type");
2929+
2930+
VectorType expectedMaskType =
2931+
vector::detail::transferMaskType(vectorType, permutationMap);
2932+
if (maskType && expectedMaskType != maskType)
2933+
return op->emitOpError("expects mask type consistent with permutation "
2934+
"map: ")
2935+
<< maskType;
29302936
}
29312937

29322938
if (permutationMap.getNumSymbols() != 0)
@@ -2936,11 +2942,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
29362942
return op->emitOpError("requires a permutation_map with input dims of the "
29372943
"same rank as the source type");
29382944

2939-
if (maskType && maskType != inferredMaskType)
2940-
return op->emitOpError("inferred mask type (")
2941-
<< inferredMaskType << ") and mask operand type (" << maskType
2942-
<< ") don't match";
2943-
29442945
if (inBounds) {
29452946
if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
29462947
return op->emitOpError("expects the optional in_bounds attr of same rank "
@@ -2983,19 +2984,6 @@ void TransferReadOp::print(OpAsmPrinter &p) {
29832984
p << " : " << getShapedType() << ", " << getVectorType();
29842985
}
29852986

2986-
/// Infers the mask type for a transfer read given its vector type and
2987-
/// permutation map. The mask in a transfer read operation applies to the
2988-
/// tensor/buffer reading part of it and its type should match the shape read
2989-
/// *before* any permutation or broadcasting.
2990-
static VectorType inferTransferReadMaskType(VectorType vecType,
2991-
AffineMap permMap) {
2992-
auto i1Type = IntegerType::get(permMap.getContext(), 1);
2993-
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
2994-
assert(invPermMap && "Inversed permutation map couldn't be computed");
2995-
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
2996-
return VectorType::get(maskShape, i1Type);
2997-
}
2998-
29992987
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
30002988
auto &builder = parser.getBuilder();
30012989
SMLoc typesLoc;
@@ -3026,14 +3014,13 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
30263014
VectorType vectorType = types[1].dyn_cast<VectorType>();
30273015
if (!vectorType)
30283016
return parser.emitError(typesLoc, "requires vector type");
3029-
auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
3030-
Attribute permMapAttr = result.attributes.get(permMapAttrName);
3031-
AffineMap permMap;
3032-
if (!permMapAttr) {
3033-
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3034-
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3035-
} else {
3036-
permMap = permMapAttr.cast<AffineMapAttr>().getValue();
3017+
auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
3018+
Attribute mapAttr = result.attributes.get(permutationAttrName);
3019+
if (!mapAttr) {
3020+
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3021+
// Update `mapAttr` that is used later to determine mask type.
3022+
mapAttr = AffineMapAttr::get(permMap);
3023+
result.attributes.set(permutationAttrName, mapAttr);
30373024
}
30383025
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
30393026
parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -3044,9 +3031,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
30443031
if (shapedType.getElementType().dyn_cast<VectorType>())
30453032
return parser.emitError(
30463033
maskInfo.location, "does not support masks with vector element type");
3034+
auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
30473035
// Instead of adding the mask type as an op type, compute it based on the
30483036
// vector type and the permutation map (to keep the type signature small).
3049-
auto maskType = inferTransferReadMaskType(vectorType, permMap);
3037+
auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
30503038
if (parser.resolveOperand(maskInfo, maskType, result.operands))
30513039
return failure();
30523040
}
@@ -3064,17 +3052,13 @@ LogicalResult TransferReadOp::verify() {
30643052
VectorType maskType = getMaskType();
30653053
auto paddingType = getPadding().getType();
30663054
auto permutationMap = getPermutationMap();
3067-
VectorType inferredMaskType =
3068-
maskType ? inferTransferReadMaskType(vectorType, permutationMap)
3069-
: VectorType();
30703055
auto sourceElementType = shapedType.getElementType();
30713056

30723057
if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
30733058
return emitOpError("requires ") << shapedType.getRank() << " indices";
30743059

30753060
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3076-
shapedType, vectorType, maskType,
3077-
inferredMaskType, permutationMap,
3061+
shapedType, vectorType, maskType, permutationMap,
30783062
getInBounds() ? *getInBounds() : ArrayAttr())))
30793063
return failure();
30803064

@@ -3438,18 +3422,6 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
34383422
build(builder, result, vector, dest, indices, permutationMap, inBounds);
34393423
}
34403424

3441-
/// Infers the mask type for a transfer write given its vector type and
3442-
/// permutation map. The mask in a transfer read operation applies to the
3443-
/// tensor/buffer writing part of it and its type should match the shape written
3444-
/// *after* any permutation.
3445-
static VectorType inferTransferWriteMaskType(VectorType vecType,
3446-
AffineMap permMap) {
3447-
auto i1Type = IntegerType::get(permMap.getContext(), 1);
3448-
SmallVector<int64_t, 8> maskShape =
3449-
compressUnusedDims(permMap).compose(vecType.getShape());
3450-
return VectorType::get(maskShape, i1Type);
3451-
}
3452-
34533425
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
34543426
OperationState &result) {
34553427
auto &builder = parser.getBuilder();
@@ -3477,14 +3449,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
34773449
ShapedType shapedType = types[1].dyn_cast<ShapedType>();
34783450
if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
34793451
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3480-
auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3481-
auto permMapAttr = result.attributes.get(permMapAttrName);
3482-
AffineMap permMap;
3483-
if (!permMapAttr) {
3484-
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3485-
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3486-
} else {
3487-
permMap = permMapAttr.cast<AffineMapAttr>().getValue();
3452+
auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3453+
auto attr = result.attributes.get(permutationAttrName);
3454+
if (!attr) {
3455+
auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3456+
result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
34883457
}
34893458
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
34903459
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
@@ -3494,7 +3463,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
34943463
if (shapedType.getElementType().dyn_cast<VectorType>())
34953464
return parser.emitError(
34963465
maskInfo.location, "does not support masks with vector element type");
3497-
auto maskType = inferTransferWriteMaskType(vectorType, permMap);
3466+
auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
34983467
if (parser.resolveOperand(maskInfo, maskType, result.operands))
34993468
return failure();
35003469
}
@@ -3520,9 +3489,6 @@ LogicalResult TransferWriteOp::verify() {
35203489
VectorType vectorType = getVectorType();
35213490
VectorType maskType = getMaskType();
35223491
auto permutationMap = getPermutationMap();
3523-
VectorType inferredMaskType =
3524-
maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
3525-
: VectorType();
35263492

35273493
if (llvm::size(getIndices()) != shapedType.getRank())
35283494
return emitOpError("requires ") << shapedType.getRank() << " indices";
@@ -3533,8 +3499,7 @@ LogicalResult TransferWriteOp::verify() {
35333499
return emitOpError("should not have broadcast dimensions");
35343500

35353501
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3536-
shapedType, vectorType, maskType,
3537-
inferredMaskType, permutationMap,
3502+
shapedType, vectorType, maskType, permutationMap,
35383503
getInBounds() ? *getInBounds() : ArrayAttr())))
35393504
return failure();
35403505

mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ struct TransferReadPermutationLowering
8383
newVectorShape[pos.value()] = originalShape[pos.index()];
8484
}
8585

86+
// Transpose mask operand.
87+
Value newMask;
88+
if (op.getMask()) {
89+
// Remove unused dims from the permutation map. E.g.:
90+
// E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
91+
// comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
92+
auto comp = compressUnusedDims(map);
93+
// Get positions of remaining result dims.
94+
// E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
95+
// maskTransposeIndices = [ 2, 1, 0]
96+
SmallVector<int64_t> maskTransposeIndices;
97+
for (unsigned i = 0; i < comp.getNumResults(); ++i) {
98+
if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
99+
maskTransposeIndices.push_back(expr.getPosition());
100+
}
101+
102+
newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.getMask(),
103+
maskTransposeIndices);
104+
}
105+
86106
// Transpose in_bounds attribute.
87107
ArrayAttr newInBoundsAttr =
88108
op.getInBounds() ? transposeInBoundsAttr(
@@ -94,8 +114,7 @@ struct TransferReadPermutationLowering
94114
VectorType::get(newVectorShape, op.getVectorType().getElementType());
95115
Value newRead = rewriter.create<vector::TransferReadOp>(
96116
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
97-
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
98-
newInBoundsAttr);
117+
AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr);
99118

100119
// Transpose result of transfer_read.
101120
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
@@ -149,6 +168,11 @@ struct TransferWritePermutationLowering
149168
return expr.dyn_cast<AffineDimExpr>().getPosition();
150169
});
151170

171+
// Transpose mask operand.
172+
Value newMask = op.getMask() ? rewriter.create<vector::TransposeOp>(
173+
op.getLoc(), op.getMask(), indices)
174+
: Value();
175+
152176
// Transpose in_bounds attribute.
153177
ArrayAttr newInBoundsAttr =
154178
op.getInBounds() ? transposeInBoundsAttr(
@@ -162,7 +186,7 @@ struct TransferWritePermutationLowering
162186
map.getNumDims(), map.getNumResults(), rewriter.getContext());
163187
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
164188
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
165-
op.getMask(), newInBoundsAttr);
189+
newMask, newInBoundsAttr);
166190

167191
return success();
168192
}

mlir/lib/Interfaces/VectorInterfaces.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@
1010

1111
using namespace mlir;
1212

13+
VectorType mlir::vector::detail::transferMaskType(VectorType vecType,
14+
AffineMap map) {
15+
auto i1Type = IntegerType::get(map.getContext(), 1);
16+
SmallVector<int64_t, 8> shape;
17+
for (int64_t i = 0; i < vecType.getRank(); ++i) {
18+
// Only result dims have a corresponding dim in the mask.
19+
if (map.getResult(i).template isa<AffineDimExpr>()) {
20+
shape.push_back(vecType.getDimSize(i));
21+
}
22+
}
23+
return VectorType::get(shape, i1Type);
24+
}
25+
1326
//===----------------------------------------------------------------------===//
1427
// VectorUnroll Interfaces
1528
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToSCF/vector-to-scf-mask-and-permutation-map.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
// CHECK-LABEL: func @transfer_read_2d_mask_transposed(
77
// CHECK-DAG: %[[PADDING:.*]] = arith.constant dense<-4.200000e+01> : vector<9xf32>
8-
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<4x9xi1>
8+
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<{{.*}}> : vector<9x4xi1>
99
// CHECK: %[[MASK_MEM:.*]] = memref.alloca() : memref<vector<4x9xi1>>
10-
// CHECK: memref.store %[[MASK]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
10+
// CHECK: %[[MASK_T:.*]] = vector.transpose %[[MASK]], [1, 0] : vector<9x4xi1> to vector<4x9xi1>
11+
// CHECK: memref.store %[[MASK_T]], %[[MASK_MEM]][] : memref<vector<4x9xi1>>
1112
// CHECK: %[[MASK_CASTED:.*]] = vector.type_cast %[[MASK_MEM]] : memref<vector<4x9xi1>> to memref<4xvector<9xi1>>
1213
// CHECK: scf.for {{.*}} {
1314
// CHECK: scf.if {{.*}} {
@@ -24,10 +25,11 @@
2425
func.func @transfer_read_2d_mask_transposed(
2526
%A : memref<?x?xf32>, %base1: index, %base2: index) -> (vector<9x4xf32>) {
2627
%fm42 = arith.constant -42.0: f32
27-
%mask = arith.constant dense<[[1, 0, 1, 0, 1, 1, 1, 0, 1],
28-
[0, 0, 1, 1, 1, 1, 1, 0, 1],
29-
[1, 1, 1, 1, 1, 1, 1, 0, 1],
30-
[0, 0, 1, 0, 1, 1, 1, 0, 1]]> : vector<4x9xi1>
28+
%mask = arith.constant dense<[[1, 0, 1, 0], [0, 0, 1, 0],
29+
[1, 1, 1, 1], [0, 1, 1, 0],
30+
[1, 1, 1, 1], [1, 1, 1, 1],
31+
[1, 1, 1, 1], [0, 0, 0, 0],
32+
[1, 1, 1, 1]]> : vector<9x4xi1>
3133
%f = vector.transfer_read %A[%base1, %base2], %fm42, %mask
3234
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
3335
memref<?x?xf32>, vector<9x4xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
4949
%v0 = vector.splat %c0 : vector<4x3xi32>
5050
%vi0 = vector.splat %i0 : vector<4x3xindex>
5151
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
52-
%m2 = vector.splat %i1 : vector<4x5xi1>
52+
%m2 = vector.splat %i1 : vector<5x4xi1>
5353
//
5454
// CHECK: vector.transfer_read
5555
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>

0 commit comments

Comments
 (0)