Skip to content

Commit d592c8e

Browse files
dcaballehanhanW
andauthored
Reapply "[mlir][vector] Drop inner unit dims for transfer ops on dynamic shapes." (#80712) (#81778)
This reverts commit b4c7152. Downstream regression due to another issue that this PR exposes. We have identified the work-items to fix the new issue here: iree-org/iree#16406 Co-authored-by: Han-Chung Wang <[email protected]>
1 parent 9b80ab4 commit d592c8e

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,7 @@ class DropInnerMostUnitDimsTransferRead
12371237
return failure();
12381238

12391239
auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
1240-
if (!srcType || !srcType.hasStaticShape())
1240+
if (!srcType)
12411241
return failure();
12421242

12431243
if (!readOp.getPermutationMap().isMinorIdentity())
@@ -1261,19 +1261,21 @@ class DropInnerMostUnitDimsTransferRead
12611261
targetType.getElementType());
12621262

12631263
auto loc = readOp.getLoc();
1264+
SmallVector<OpFoldResult> sizes =
1265+
memref::getMixedSizes(rewriter, loc, readOp.getSource());
1266+
SmallVector<OpFoldResult> offsets(srcType.getRank(),
1267+
rewriter.getIndexAttr(0));
1268+
SmallVector<OpFoldResult> strides(srcType.getRank(),
1269+
rewriter.getIndexAttr(1));
12641270
MemRefType resultMemrefType =
12651271
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
1266-
SmallVector<int64_t> offsets(srcType.getRank(), 0);
1267-
SmallVector<int64_t> strides(srcType.getRank(), 1);
1268-
12691272
ArrayAttr inBoundsAttr =
12701273
readOp.getInBounds()
12711274
? rewriter.getArrayAttr(
12721275
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
12731276
: ArrayAttr();
12741277
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1275-
loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
1276-
strides);
1278+
loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
12771279
auto permMap = getTransferMinorIdentityMap(
12781280
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
12791281
Value result = rewriter.create<vector::TransferReadOp>(

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@ func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8,
1616

1717
// -----
1818

19+
func.func @contiguous_outer_dyn_inner_most_view(%in: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
20+
%c0 = arith.constant 0 : index
21+
%cst = arith.constant 0.0 : f32
22+
%0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
23+
return %0 : vector<1x8x1xf32>
24+
}
25+
// CHECK: func @contiguous_outer_dyn_inner_most_view(
26+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
27+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
28+
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
29+
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], 1, 8, 1] [1, 1, 1, 1]
30+
// CHECK-SAME: memref<?x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>
31+
// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
32+
// CHECK-SAME: memref<?x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
33+
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
34+
// CHECK: return %[[RESULT]]
35+
36+
// -----
37+
1938
func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
2039
%c0 = arith.constant 0 : index
2140
%f0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)