Skip to content

Commit e6e55e6

Browse files
[mlir][vector] Fix off-by-one error in getTransferChunkAccessed (llvm#70292)
If a dimension does not appear in the permutation map of a vector transfer op, the size of the accessed slice in that dimension is `1`. Before this fix, `getTransferChunkAccessed` used to return `0` for such dimensions, which would means that `0` elements in the underlying tensor/memref are accessed. Note: There is no test case that fails due to this bug and because this interface method is currently only used in one place, it is hard to write a regression test. This fix is in preparation of subset hoisting functionality that will be added in subsequent commits.
1 parent 5270df3 commit e6e55e6

File tree

3 files changed

+48
-23
lines changed

3 files changed

+48
-23
lines changed

mlir/include/mlir/Interfaces/VectorInterfaces.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -257,22 +257,22 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
257257
>,
258258
InterfaceMethod<
259259
/*desc=*/[{
260-
Return an upper-bound shape accessed by the transfer op within the
261-
tensor/memref operand.
260+
Return the shape of the hyperrectangular slice within the tensor/memref
261+
operand that is accessed by the transfer op.
262262
For example:
263263
```
264-
vector.transfer %w0[%i, %j] {
265-
permutation_map = affine_map<(d0, d1) -> (d1, d0, 0)>} :
266-
tensor<?x?xf32>, vector<4x2x6xf32>
264+
vector.transfer %w0[%i, %j, %k] {
265+
permutation_map = affine_map<(d0, d1, d2) -> (d1, d0, 0)>} :
266+
tensor<?x?x?xf32>, vector<4x2x6xf32>
267267
```
268-
returns a shape [2, 4].
268+
returns a shape [2, 4, 1].
269269
}],
270270
/*retTy=*/"SmallVector<int64_t>",
271271
/*methodName=*/"getTransferChunkAccessed",
272272
/*args=*/(ins),
273273
/*methodBody=*/"",
274274
/*defaultImplementation=*/[{
275-
SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 0);
275+
SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 1);
276276
for (auto vecDims : llvm::zip($_op.getPermutationMap().getResults(),
277277
$_op.getVectorType().getShape())) {
278278
AffineExpr dim = std::get<0>(vecDims);

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4004,35 +4004,36 @@ struct TransferReadAfterWriteToBroadcast
40044004
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
40054005
if (!defWrite)
40064006
return failure();
4007-
4008-
SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
4009-
Value vec;
4010-
if (readOp.getIndices() == defWrite.getIndices() &&
4011-
readOp.getMask() == defWrite.getMask()) {
4012-
SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
4013-
// TODO: If the writeDim is a superset of the read dims we could do an
4014-
// extract_strided_slice.
4015-
if (writeDims == readDims)
4016-
vec = defWrite.getVector();
4017-
}
4007+
// TODO: If the written transfer chunk is a superset of the read transfer
4008+
// chunk we could do an extract_strided_slice.
4009+
if (readOp.getTransferChunkAccessed() !=
4010+
defWrite.getTransferChunkAccessed())
4011+
return failure();
4012+
// TODO: Support cases where a dim is explicitly written but implicitly
4013+
// read (i.e., a unit dim that is rank reduced).
4014+
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
4015+
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
4016+
return failure();
4017+
if (readOp.getIndices() != defWrite.getIndices() ||
4018+
readOp.getMask() != defWrite.getMask())
4019+
return failure();
4020+
Value vec = defWrite.getVector();
40184021
// TODO: loop through the chain of transfer_write if we can prove that they
40194022
// don't overlap with the transfer_read. This requires improving
40204023
// `isDisjointTransferIndices` helper.
4021-
if (!vec)
4022-
return failure();
4023-
SmallVector<unsigned> permutation;
40244024
AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
40254025
AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
40264026
AffineMap map = readMap.compose(writeMap);
40274027
if (map.getNumResults() == 0)
40284028
return failure();
4029-
// Calculate the permuation to apply to go from the vector stored to the
4029+
// Calculate the permutation to apply to go from the vector stored to the
40304030
// vector read.
4031+
SmallVector<unsigned> permutation;
40314032
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
40324033
return failure();
40334034

40344035
Location loc = readOp.getLoc();
4035-
// Calculate the broadcast shape by applying the reverse permuation to the
4036+
// Calculate the broadcast shape by applying the reverse permutation to the
40364037
// final shape we want.
40374038
ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
40384039
SmallVector<int64_t> broadcastShape(destShape.size());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,3 +2400,27 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
24002400
%2 = vector.shape_cast %1 : vector<4x1x1xi1> to vector<4xi1>
24012401
return %2 : vector<4xi1>
24022402
}
2403+
2404+
// -----
2405+
2406+
// TODO: This IR could be canonicalized but the canonicalization pattern is not
2407+
// smart enough. For now, just make sure that we do not crash.
2408+
2409+
// CHECK-LABEL: func.func @load_store_forwarding_rank_mismatch(
2410+
// CHECK: vector.transfer_write
2411+
// CHECK: vector.transfer_read
2412+
func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: tensor<4x4x4xf32>) -> (vector<1x100x4x5xf32>) {
2413+
%c0 = arith.constant 0 : index
2414+
%cf0 = arith.constant 0.0 : f32
2415+
// d0 is explicitly written.
2416+
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
2417+
{in_bounds = [true, true, true],
2418+
permutation_map = affine_map<(d0, d1, d2) -> (d2, d1, d0)>} :
2419+
vector<4x1x1xf32>, tensor<4x4x4xf32>
2420+
// d0 is implicitly read (rank-reduction of unit dim).
2421+
%r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
2422+
{in_bounds = [true, true, true, true],
2423+
permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
2424+
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
2425+
return %r : vector<1x100x4x5xf32>
2426+
}

0 commit comments

Comments
 (0)