@@ -1835,9 +1835,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
1835
1835
// / Operates under a scoped context to build the intersection between the
1836
1836
// / view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
1837
1837
// TODO: view intersection/union/differences should be a proper std op.
1838
- static Value createSubViewIntersection (OpBuilder &b,
1839
- VectorTransferOpInterface xferOp,
1840
- Value alloc) {
1838
+ static std::pair<Value, Value> createSubViewIntersection (
1839
+ OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
1841
1840
ImplicitLocOpBuilder lb (xferOp.getLoc (), b);
1842
1841
int64_t memrefRank = xferOp.getShapedType ().getRank ();
1843
1842
// TODO: relax this precondition, will require rank-reducing subviews.
@@ -1864,11 +1863,15 @@ static Value createSubViewIntersection(OpBuilder &b,
1864
1863
sizes.push_back (affineMin);
1865
1864
});
1866
1865
1867
- SmallVector<OpFoldResult, 4 > indices = llvm::to_vector<4 >(llvm::map_range (
1866
+ SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4 >(llvm::map_range (
1868
1867
xferOp.indices (), [](Value idx) -> OpFoldResult { return idx; }));
1869
- return lb.create <memref::SubViewOp>(
1870
- isaWrite ? alloc : xferOp.source (), indices, sizes,
1871
- SmallVector<OpFoldResult>(memrefRank, OpBuilder (xferOp).getIndexAttr (1 )));
1868
+ SmallVector<OpFoldResult> destIndices (memrefRank, b.getIndexAttr (0 ));
1869
+ SmallVector<OpFoldResult> strides (memrefRank, b.getIndexAttr (1 ));
1870
+ auto copySrc = lb.create <memref::SubViewOp>(
1871
+ isaWrite ? alloc : xferOp.source (), srcIndices, sizes, strides);
1872
+ auto copyDest = lb.create <memref::SubViewOp>(
1873
+ isaWrite ? xferOp.source () : alloc, destIndices, sizes, strides);
1874
+ return std::make_pair (copySrc, copyDest);
1872
1875
}
1873
1876
1874
1877
// / Given an `xferOp` for which:
@@ -1877,14 +1880,15 @@ static Value createSubViewIntersection(OpBuilder &b,
1877
1880
// / Produce IR resembling:
1878
1881
// / ```
1879
1882
// / %1:3 = scf.if (%inBounds) {
1880
- // / memref.cast %A: memref<A...> to compatibleMemRefType
1883
+ // / %view = memref.cast %A: memref<A...> to compatibleMemRefType
1881
1884
// / scf.yield %view, ... : compatibleMemRefType, index, index
1882
1885
// / } else {
1883
1886
// / %2 = linalg.fill(%pad, %alloc)
1884
1887
// / %3 = subview %view [...][...][...]
1885
- // / linalg.copy(%3, %alloc)
1886
- // / memref.cast %alloc: memref<B...> to compatibleMemRefType
1887
- // / scf.yield %4, ... : compatibleMemRefType, index, index
1888
+ // / %4 = subview %alloc [0, 0] [...] [...]
1889
+ // / linalg.copy(%3, %4)
1890
+ // / %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
1891
+ // / scf.yield %5, ... : compatibleMemRefType, index, index
1888
1892
// / }
1889
1893
// / ```
1890
1894
// / Return the produced scf::IfOp.
@@ -1910,9 +1914,9 @@ createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
1910
1914
b.create <linalg::FillOp>(loc, xferOp.padding (), alloc);
1911
1915
// Take partial subview of memref which guarantees no dimension
1912
1916
// overflows.
1913
- Value memRefSubView = createSubViewIntersection (
1917
+ std::pair< Value, Value> copyArgs = createSubViewIntersection (
1914
1918
b, cast<VectorTransferOpInterface>(xferOp.getOperation ()), alloc);
1915
- b.create <linalg::CopyOp>(loc, memRefSubView, alloc );
1919
+ b.create <linalg::CopyOp>(loc, copyArgs. first , copyArgs. second );
1916
1920
Value casted =
1917
1921
b.create <memref::CastOp>(loc, alloc, compatibleMemRefType);
1918
1922
scf::ValueVector viewAndIndices{casted};
@@ -2030,7 +2034,8 @@ getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
2030
2034
// / %notInBounds = xor %inBounds, %true
2031
2035
// / scf.if (%notInBounds) {
2032
2036
// / %3 = subview %alloc [...][...][...]
2033
- // / linalg.copy(%3, %view)
2037
+ // / %4 = subview %view [0, 0][...][...]
2038
+ // / linalg.copy(%3, %4)
2034
2039
// / }
2035
2040
// / ```
2036
2041
static void createFullPartialLinalgCopy (OpBuilder &b,
@@ -2040,9 +2045,9 @@ static void createFullPartialLinalgCopy(OpBuilder &b,
2040
2045
auto notInBounds =
2041
2046
lb.create <XOrOp>(inBoundsCond, lb.create <ConstantIntOp>(true , 1 ));
2042
2047
lb.create <scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
2043
- Value memRefSubView = createSubViewIntersection (
2048
+ std::pair< Value, Value> copyArgs = createSubViewIntersection (
2044
2049
b, cast<VectorTransferOpInterface>(xferOp.getOperation ()), alloc);
2045
- b.create <linalg::CopyOp>(loc, memRefSubView, xferOp. source () );
2050
+ b.create <linalg::CopyOp>(loc, copyArgs. first , copyArgs. second );
2046
2051
b.create <scf::YieldOp>(loc, ValueRange{});
2047
2052
});
2048
2053
}
0 commit comments