Skip to content

Commit b4c7152

Browse files
authored
Revert "[mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes." (#80712)
Reverts #79752 because it is causing regressions in downstream projects.
1 parent cb8d83a commit b4c7152

File tree

2 files changed

+12
-57
lines changed

2 files changed

+12
-57
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,7 +1236,7 @@ class DropInnerMostUnitDimsTransferRead
12361236
return failure();
12371237

12381238
auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1239-
if (!srcType)
1239+
if (!srcType || !srcType.hasStaticShape())
12401240
return failure();
12411241

12421242
if (!readOp.getPermutationMap().isMinorIdentity())
@@ -1260,21 +1260,19 @@ class DropInnerMostUnitDimsTransferRead
12601260
targetType.getElementType());
12611261

12621262
auto loc = readOp.getLoc();
1263-
SmallVector<OpFoldResult> sizes =
1264-
memref::getMixedSizes(rewriter, loc, readOp.getSource());
1265-
SmallVector<OpFoldResult> offsets(srcType.getRank(),
1266-
rewriter.getIndexAttr(0));
1267-
SmallVector<OpFoldResult> strides(srcType.getRank(),
1268-
rewriter.getIndexAttr(1));
12691263
MemRefType resultMemrefType =
12701264
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
1265+
SmallVector<int64_t> offsets(srcType.getRank(), 0);
1266+
SmallVector<int64_t> strides(srcType.getRank(), 1);
1267+
12711268
ArrayAttr inBoundsAttr =
12721269
readOp.getInBounds()
12731270
? rewriter.getArrayAttr(
12741271
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
12751272
: ArrayAttr();
12761273
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1277-
loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
1274+
loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
1275+
strides);
12781276
auto permMap = getTransferMinorIdentityMap(
12791277
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
12801278
Value result = rewriter.create<vector::TransferReadOp>(
@@ -1320,7 +1318,7 @@ class DropInnerMostUnitDimsTransferWrite
13201318
return failure();
13211319

13221320
auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
1323-
if (!srcType)
1321+
if (!srcType || !srcType.hasStaticShape())
13241322
return failure();
13251323

13261324
if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -1343,23 +1341,20 @@ class DropInnerMostUnitDimsTransferWrite
13431341
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
13441342
targetType.getElementType());
13451343

1346-
Location loc = writeOp.getLoc();
1347-
SmallVector<OpFoldResult> sizes =
1348-
memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1349-
SmallVector<OpFoldResult> offsets(srcType.getRank(),
1350-
rewriter.getIndexAttr(0));
1351-
SmallVector<OpFoldResult> strides(srcType.getRank(),
1352-
rewriter.getIndexAttr(1));
13531344
MemRefType resultMemrefType =
13541345
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
1346+
SmallVector<int64_t> offsets(srcType.getRank(), 0);
1347+
SmallVector<int64_t> strides(srcType.getRank(), 1);
13551348
ArrayAttr inBoundsAttr =
13561349
writeOp.getInBounds()
13571350
? rewriter.getArrayAttr(
13581351
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
13591352
: ArrayAttr();
13601353

1354+
Location loc = writeOp.getLoc();
13611355
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1362-
loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
1356+
loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
1357+
strides);
13631358
auto permMap = getTransferMinorIdentityMap(
13641359
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
13651360

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,6 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
1616

1717
// -----
1818

19-
func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
20-
%c0 = arith.constant 0 : index
21-
%cst = arith.constant 0.0 : f32
22-
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
23-
return %0 : vector<1x8x1xf32>
24-
}
25-
// CHECK: func @contiguous_outer_dyn_inner_most_view(
26-
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
27-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
28-
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
29-
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
30-
// CHECK-SAME: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
31-
// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
32-
// CHECK-SAME: memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
33-
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
34-
// CHECK: return %[[RESULT]]
35-
36-
// -----
37-
3819
func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
3920
%c0 = arith.constant 0 : index
4021
%f0 = arith.constant 0.0 : f32
@@ -138,27 +119,6 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
138119

139120
// -----
140121

141-
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
142-
%c0 = arith.constant 0 : index
143-
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
144-
{in_bounds = [true, true, true, true]}
145-
: vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
146-
return
147-
}
148-
// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
149-
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
150-
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
151-
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
152-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
153-
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
154-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
155-
// CHECK-SAME: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
156-
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
157-
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
158-
// CHECK-SAME: [%[[IDX]], %[[C0]], %[[C0]]]
159-
160-
// -----
161-
162122
func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
163123
%c0 = arith.constant 0 : index
164124
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]

0 commit comments

Comments
 (0)