Skip to content

Commit ffdf0a3

Browse files
[mlir][vector] Fix bug in vector-transfer-full-partial-split
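When splitting with linalg.copy, the copy cannot write into the destination buffer directly, because the destination's shape generally differs from that of the partial source subview. Instead, the copy now writes into a subview of the destination (the alloc for transfer reads, the original memref for transfer writes) taken at offset 0 with the same sizes as the source subview.

Differential Revision: https://reviews.llvm.org/D110512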
1 parent 683e506 commit ffdf0a3

File tree

2 files changed

+31
-22
lines changed

2 files changed

+31
-22
lines changed

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1835,9 +1835,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
18351835
/// Operates under a scoped context to build the intersection between the
18361836
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
18371837
// TODO: view intersection/union/differences should be a proper std op.
1838-
static Value createSubViewIntersection(OpBuilder &b,
1839-
VectorTransferOpInterface xferOp,
1840-
Value alloc) {
1838+
static std::pair<Value, Value> createSubViewIntersection(
1839+
OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
18411840
ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
18421841
int64_t memrefRank = xferOp.getShapedType().getRank();
18431842
// TODO: relax this precondition, will require rank-reducing subviews.
@@ -1864,11 +1863,15 @@ static Value createSubViewIntersection(OpBuilder &b,
18641863
sizes.push_back(affineMin);
18651864
});
18661865

1867-
SmallVector<OpFoldResult, 4> indices = llvm::to_vector<4>(llvm::map_range(
1866+
SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
18681867
xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
1869-
return lb.create<memref::SubViewOp>(
1870-
isaWrite ? alloc : xferOp.source(), indices, sizes,
1871-
SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
1868+
SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
1869+
SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
1870+
auto copySrc = lb.create<memref::SubViewOp>(
1871+
isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
1872+
auto copyDest = lb.create<memref::SubViewOp>(
1873+
isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
1874+
return std::make_pair(copySrc, copyDest);
18721875
}
18731876

18741877
/// Given an `xferOp` for which:
@@ -1877,14 +1880,15 @@ static Value createSubViewIntersection(OpBuilder &b,
18771880
/// Produce IR resembling:
18781881
/// ```
18791882
/// %1:3 = scf.if (%inBounds) {
1880-
/// memref.cast %A: memref<A...> to compatibleMemRefType
1883+
/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
18811884
/// scf.yield %view, ... : compatibleMemRefType, index, index
18821885
/// } else {
18831886
/// %2 = linalg.fill(%pad, %alloc)
18841887
/// %3 = subview %view [...][...][...]
1885-
/// linalg.copy(%3, %alloc)
1886-
/// memref.cast %alloc: memref<B...> to compatibleMemRefType
1887-
/// scf.yield %4, ... : compatibleMemRefType, index, index
1888+
/// %4 = subview %alloc [0, 0] [...] [...]
1889+
/// linalg.copy(%3, %4)
1890+
/// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
1891+
/// scf.yield %5, ... : compatibleMemRefType, index, index
18881892
/// }
18891893
/// ```
18901894
/// Return the produced scf::IfOp.
@@ -1910,9 +1914,9 @@ createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
19101914
b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
19111915
// Take partial subview of memref which guarantees no dimension
19121916
// overflows.
1913-
Value memRefSubView = createSubViewIntersection(
1917+
std::pair<Value, Value> copyArgs = createSubViewIntersection(
19141918
b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
1915-
b.create<linalg::CopyOp>(loc, memRefSubView, alloc);
1919+
b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
19161920
Value casted =
19171921
b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
19181922
scf::ValueVector viewAndIndices{casted};
@@ -2030,7 +2034,8 @@ getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
20302034
/// %notInBounds = xor %inBounds, %true
20312035
/// scf.if (%notInBounds) {
20322036
/// %3 = subview %alloc [...][...][...]
2033-
/// linalg.copy(%3, %view)
2037+
/// %4 = subview %view [0, 0][...][...]
2038+
/// linalg.copy(%3, %4)
20342039
/// }
20352040
/// ```
20362041
static void createFullPartialLinalgCopy(OpBuilder &b,
@@ -2040,9 +2045,9 @@ static void createFullPartialLinalgCopy(OpBuilder &b,
20402045
auto notInBounds =
20412046
lb.create<XOrOp>(inBoundsCond, lb.create<ConstantIntOp>(true, 1));
20422047
lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
2043-
Value memRefSubView = createSubViewIntersection(
2048+
std::pair<Value, Value> copyArgs = createSubViewIntersection(
20442049
b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
2045-
b.create<linalg::CopyOp>(loc, memRefSubView, xferOp.source());
2050+
b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
20462051
b.create<scf::YieldOp>(loc, ValueRange{});
20472052
});
20482053
}

mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,8 @@ func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -
8181
// LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
8282
// LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
8383
// LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_stride_8x1]]>
84-
// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<4x8xf32>
84+
// LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
85+
// LINALG: linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<?x?xf32, #{{.*}}>
8586
// LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] :
8687
// LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
8788
// LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -172,7 +173,8 @@ func @split_vector_transfer_read_strided_2d(
172173
// LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
173174
// LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
174175
// LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_stride_1]]>
175-
// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<4x8xf32>
176+
// LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
177+
// LINALG: linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<?x?xf32, #{{.*}}>
176178
// LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] :
177179
// LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
178180
// LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -276,8 +278,9 @@ func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %
276278
// LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
277279
// LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
278280
// LINALG-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP4]]>
279-
// LINALG: linalg.copy(%[[VAL_22]], %[[DEST]])
280-
// LINALG-SAME: : memref<?x?xf32, #[[MAP4]]>, memref<?x8xf32>
281+
// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
282+
// LINALG: linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
283+
// LINALG-SAME: : memref<?x?xf32, #[[MAP4]]>, memref<?x?xf32, #{{.*}}>
281284
// LINALG: }
282285
// LINALG: return
283286
// LINALG: }
@@ -384,8 +387,9 @@ func @split_vector_transfer_write_strided_2d(
384387
// LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
385388
// LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
386389
// LINALG-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP5]]>
387-
// LINALG: linalg.copy(%[[VAL_22]], %[[DEST]])
388-
// LINALG-SAME: : memref<?x?xf32, #[[MAP5]]>, memref<7x8xf32, #[[MAP0]]>
390+
// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
391+
// LINALG: linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
392+
// LINALG-SAME: : memref<?x?xf32, #[[MAP5]]>, memref<?x?xf32, #[[MAP0]]>
389393
// LINALG: }
390394
// LINALG: return
391395
// LINALG: }

0 commit comments

Comments
 (0)