
Commit 12b676d

[mlir][vector] Drop innermost unit dims on transfer_write. (#78554)
1 parent 2c78f3b commit 12b676d

File tree

2 files changed: +218 -54 lines changed


mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 165 additions & 54 deletions
@@ -1152,8 +1152,78 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims that can be folded away from transfer ops. It
+/// returns a failure if it cannot determine the number of dims to be folded.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost unit dims
+/// can be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type with
+/// [8192, 16, 8, 1] strides.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice. Thus, we have to offset the check index with `rankDiff` in
+  // `srcStrides` and source dim sizes.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Check that the inner dim size is 1 for both memref type and vector slice.
+    // It can be folded only if they are 1 and the stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] != 1 ||
+        srcType.getDimSize(dim + rankDiff) != 1 ||
+        vectorType.getDimSize(dim) != 1)
+      break;
+    result++;
+  }
+  return result;
+}
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
+/// two, it returns memref<512x16xf32>.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                     MemRefType srcType,
+                                                     size_t dimsToDrop) {
+  MemRefType resultMemrefType;
+  MemRefLayoutAttrInterface layout = srcType.getLayout();
+  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), nullptr,
+                           srcType.getMemorySpace());
+  }
+  MemRefLayoutAttrInterface updatedLayout;
+  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+    updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                           strided.getOffset(), strides);
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), updatedLayout,
+                           srcType.getMemorySpace());
+  }
+
+  // Non-strided layout case.
+  AffineMap map = srcType.getLayout().getAffineMap();
+  int numSymbols = map.getNumSymbols();
+  for (size_t i = 0; i < dimsToDrop; ++i) {
+    int dim = srcType.getRank() - i - 1;
+    map = map.replace(builder.getAffineDimExpr(dim),
+                      builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                      numSymbols);
+  }
+  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                         srcType.getElementType(), updatedLayout,
+                         srcType.getMemorySpace());
+}
+
+/// Drop inner most contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+    : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,65 +1247,22 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     if (targetType.getRank() <= 1)
       return failure();
 
-    SmallVector<int64_t> srcStrides;
-    int64_t srcOffset;
-    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
-      return failure();
-
-    // According to vector.transfer_read semantics, the result can be a slice.
-    // It pads the indices with `1` starting from beginning. Thus, we have to
-    // offset the check index with `rankDiff` in `srcStrides` and source dim
-    // sizes.
-    size_t dimsToDrop = 0;
-    int rankDiff = srcType.getRank() - targetType.getRank();
-    for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
-      // Check that the inner dim size is 1 for both memref/tensor type and
-      // vector slice. It can be folded only if they are 1 and the stride is 1.
-      int dim = targetType.getRank() - i - 1;
-      if (srcStrides[dim + rankDiff] == 1 &&
-          srcType.getDimSize(dim + rankDiff) == 1 &&
-          targetType.getDimSize(dim) == 1) {
-        dimsToDrop++;
-      } else {
-        break;
-      }
-    }
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
     if (dimsToDrop == 0)
       return failure();
 
     auto resultTargetVecType =
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    MemRefType resultMemrefType;
-    MemRefLayoutAttrInterface layout = srcType.getLayout();
-    if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          nullptr, srcType.getMemorySpace());
-    } else {
-      MemRefLayoutAttrInterface updatedLayout;
-      if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
-        auto strides =
-            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                               strided.getOffset(), strides);
-      } else {
-        AffineMap map = srcType.getLayout().getAffineMap();
-        int numSymbols = map.getNumSymbols();
-        for (size_t i = 0; i < dimsToDrop; ++i) {
-          int dim = srcType.getRank() - i - 1;
-          map = map.replace(rewriter.getAffineDimExpr(dim),
-                            rewriter.getAffineConstantExpr(0),
-                            map.getNumDims() - 1, numSymbols);
-        }
-      }
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          updatedLayout, srcType.getMemorySpace());
-    }
-
     auto loc = readOp.getLoc();
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
 
@@ -1261,6 +1288,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+/// Drop inner most contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+///     vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+///       {in_bounds = [true, true, true, true, true]}
+///       : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+///     %subview = memref.subview %arg0
+///                  [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+///                  : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+///     %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+///                                      to vector<1x16x16xf32>
+///     vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+///       {in_bounds = [true, true, true]}
+///       : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support mask.
+    if (writeOp.getMask())
+      return failure();
+
+    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    auto targetType = writeOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+    ArrayAttr inBoundsAttr =
+        writeOp.getInBounds()
+            ? rewriter.getArrayAttr(
+                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+            : ArrayAttr();
+
+    Location loc = writeOp.getLoc();
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+        strides);
+    auto permMap = getTransferMinorIdentityMap(
+        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, resultTargetVecType, writeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        writeOp, shapeCast, rankedReducedView,
+        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+        // TODO: support mask.
+        /*mask=*/Value(), inBoundsAttr);
+    return success();
+  }
+};
+
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
 /// with the RHS transposed) lowering.
@@ -1696,7 +1805,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
         RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+  patterns.add<DropInnerMostUnitDimsTransferRead,
+               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+                                                   benefit);
 }
 
 void mlir::vector::populateSinkVectorBroadcastPatterns(
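
A minimal sketch of how a downstream pass could pick up both patterns through the populate function updated above. The pass name, wrapper, and include paths here are assumptions for illustration; only populateVectorTransferCollapseInnerMostContiguousDimsPatterns and applyPatternsAndFoldGreedily are existing MLIR entry points, and this snippet is not part of the commit.

// Illustrative sketch (assumed, not from this commit): a pass that applies the
// collapse-inner-most-dims patterns, which after this change include
// DropInnerMostUnitDimsTransferWrite in addition to the read pattern.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// The pass name and wrapper are hypothetical; the populate call and greedy
// driver below come from upstream MLIR.
struct CollapseInnerMostDimsPass
    : public mlir::PassWrapper<CollapseInnerMostDimsPass,
                               mlir::OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CollapseInnerMostDimsPass)

  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Registers DropInnerMostUnitDimsTransferRead and, with this commit,
    // DropInnerMostUnitDimsTransferWrite.
    mlir::vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        patterns, /*benefit=*/1);
    // Rewrite transfer ops greedily until no pattern applies anymore.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

In-tree, the same patterns are exercised by the test pass that drives the .mlir test file below.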

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 53 additions & 0 deletions
@@ -76,3 +76,56 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
 // CHECK-NOT: memref.subview
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[SRC]]
 // CHECK: return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+    {in_bounds = [true, true, true, true, true]}
+    : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+  return
+}
+// CHECK:      func.func @drop_two_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+  return
+}
+// CHECK:      func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
+    {in_bounds = [true, true, true]}
+    : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
+  return
+}
+// The inner most unit dims can not be dropped if the strides are not ones.
+// CHECK:     func.func @non_unit_strides
+// CHECK-NOT:   memref.subview
