@@ -1152,8 +1152,78 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  }
};

- // Drop inner most contiguous unit dimensions from transfer_read operand.
- class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+ /// Returns the number of dims that can be folded away from transfer ops. It
+ /// returns a failure if it cannot determine the number of dims to be folded.
+ /// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+ /// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims can
+ /// be dropped by memref.subview ops.
+ /// Example 2: it returns "1" if `srcType` is the same memref type but with
+ /// [8192, 16, 8, 1] strides, since the second unit dim then has a non-unit
+ /// stride.
+ static FailureOr<size_t>
+ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+   SmallVector<int64_t> srcStrides;
+   int64_t srcOffset;
+   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+     return failure();
+
+   // According to vector.transfer_read/write semantics, the vector can be a
+   // slice. Thus, we have to offset the check index with `rankDiff` in
+   // `srcStrides` and the source dim sizes.
+   size_t result = 0;
+   int rankDiff = srcType.getRank() - vectorType.getRank();
+   for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+     // An inner dim can be folded only if its size is 1 in both the memref
+     // type and the vector slice, and its stride is 1.
+     int dim = vectorType.getRank() - i - 1;
+     if (srcStrides[dim + rankDiff] != 1 ||
+         srcType.getDimSize(dim + rankDiff) != 1 ||
+         vectorType.getDimSize(dim) != 1)
+       break;
+     result++;
+   }
+   return result;
+ }
+
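For illustration, Example 2 from the comment above spelled out as IR (an
editorial sketch, not part of the commit; `%src` is an assumed function
argument):

    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f32
    // Only the innermost dim folds: the next unit dim has stride 8, not 1.
    %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
        {in_bounds = [true, true, true, true]}
        : memref<512x16x1x1xf32, strided<[8192, 16, 8, 1]>>, vector<16x16x1x1xf32>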
+ /// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+ /// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop`
+ /// is two, it returns memref<512x16xf32>.
+ static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                      MemRefType srcType,
+                                                      size_t dimsToDrop) {
+   MemRefType resultMemrefType;
+   MemRefLayoutAttrInterface layout = srcType.getLayout();
+   if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+     return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                            srcType.getElementType(), nullptr,
+                            srcType.getMemorySpace());
+   }
+   MemRefLayoutAttrInterface updatedLayout;
+   if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+     auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+     updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                            strided.getOffset(), strides);
+     return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                            srcType.getElementType(), updatedLayout,
+                            srcType.getMemorySpace());
+   }
+
+   // Non-strided layout case.
+   AffineMap map = srcType.getLayout().getAffineMap();
+   int numSymbols = map.getNumSymbols();
+   for (size_t i = 0; i < dimsToDrop; ++i) {
+     int dim = srcType.getRank() - i - 1;
+     map = map.replace(builder.getAffineDimExpr(dim),
+                       builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                       numSymbols);
+   }
+   return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                          srcType.getElementType(), updatedLayout,
+                          srcType.getMemorySpace());
+ }
+
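The effect on types, as a hedged illustration (the surviving strides and the
offset are preserved):

    // identity layout: memref<512x16x1x1xf32> -> memref<512x16xf32>
    // strided layout:  memref<512x16x1x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
    //                    -> memref<512x16xf32, strided<[8192, 16], offset: ?>>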
+ /// Drop innermost contiguous unit dimensions from transfer_read operand.
+ class DropInnerMostUnitDimsTransferRead
+     : public OpRewritePattern<vector::TransferReadOp> {
    using OpRewritePattern::OpRewritePattern;

    LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,65 +1247,22 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
      if (targetType.getRank() <= 1)
        return failure();

-     SmallVector<int64_t> srcStrides;
-     int64_t srcOffset;
-     if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
-       return failure();
-
-     // According to vector.transfer_read semantics, the result can be a slice.
-     // It pads the indices with `1` starting from beginning. Thus, we have to
-     // offset the check index with `rankDiff` in `srcStrides` and source dim
-     // sizes.
-     size_t dimsToDrop = 0;
-     int rankDiff = srcType.getRank() - targetType.getRank();
-     for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
-       // Check that the inner dim size is 1 for both memref/tensor type and
-       // vector slice. It can be folded only if they are 1 and the stride is 1.
-       int dim = targetType.getRank() - i - 1;
-       if (srcStrides[dim + rankDiff] == 1 &&
-           srcType.getDimSize(dim + rankDiff) == 1 &&
-           targetType.getDimSize(dim) == 1) {
-         dimsToDrop++;
-       } else {
-         break;
-       }
-     }
+     FailureOr<size_t> maybeDimsToDrop =
+         getTransferFoldableInnerUnitDims(srcType, targetType);
+     if (failed(maybeDimsToDrop))
+       return failure();
+
+     size_t dimsToDrop = maybeDimsToDrop.value();
      if (dimsToDrop == 0)
        return failure();

      auto resultTargetVecType =
          VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                          targetType.getElementType());

-     MemRefType resultMemrefType;
-     MemRefLayoutAttrInterface layout = srcType.getLayout();
-     if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
-       resultMemrefType = MemRefType::get(
-           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-           nullptr, srcType.getMemorySpace());
-     } else {
-       MemRefLayoutAttrInterface updatedLayout;
-       if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
-         auto strides =
-             llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-         updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                                strided.getOffset(), strides);
-       } else {
-         AffineMap map = srcType.getLayout().getAffineMap();
-         int numSymbols = map.getNumSymbols();
-         for (size_t i = 0; i < dimsToDrop; ++i) {
-           int dim = srcType.getRank() - i - 1;
-           map = map.replace(rewriter.getAffineDimExpr(dim),
-                             rewriter.getAffineConstantExpr(0),
-                             map.getNumDims() - 1, numSymbols);
-         }
-       }
-       resultMemrefType = MemRefType::get(
-           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-           updatedLayout, srcType.getMemorySpace());
-     }
-
      auto loc = readOp.getLoc();
+     MemRefType resultMemrefType =
+         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
      SmallVector<int64_t> offsets(srcType.getRank(), 0);
      SmallVector<int64_t> strides(srcType.getRank(), 1);
@@ -1261,6 +1288,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
    }
  };
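The read-side rewrite mirrors the transfer_write example documented below. As
an editorial sketch (value names assumed by analogy with that example):

    %0 = vector.transfer_read %arg0[%c0, %arg2, %c0, %c0, %c0], %cst
        {in_bounds = [true, true, true, true, true]}
        : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>

becomes

    %subview = memref.subview %arg0
        [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
        : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
    %1 = vector.transfer_read %subview[%c0, %arg2, %c0], %cst
        {in_bounds = [true, true, true]}
        : memref<1x512x16xf32>, vector<1x16x16xf32>
    %0 = vector.shape_cast %1 : vector<1x16x16xf32> to vector<1x16x16x1x1xf32>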

+ /// Drop innermost contiguous unit dimensions from transfer_write operand.
+ /// E.g.,
+ ///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+ ///      {in_bounds = [true, true, true, true, true]}
+ ///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+ ///
+ /// will be replaced with
+ ///
+ ///    %subview = memref.subview %arg0
+ ///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+ ///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+ ///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+ ///      to vector<1x16x16xf32>
+ ///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+ ///      {in_bounds = [true, true, true]}
+ ///      : vector<1x16x16xf32>, memref<1x512x16xf32>
+ class DropInnerMostUnitDimsTransferWrite
+     : public OpRewritePattern<vector::TransferWriteOp> {
+   using OpRewritePattern::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                 PatternRewriter &rewriter) const override {
+     // TODO: support 0-d corner case.
+     if (writeOp.getTransferRank() == 0)
+       return failure();
+
+     // TODO: support mask.
+     if (writeOp.getMask())
+       return failure();
+
+     auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+     if (!srcType || !srcType.hasStaticShape())
+       return failure();
+
+     if (!writeOp.getPermutationMap().isMinorIdentity())
+       return failure();
+
+     auto targetType = writeOp.getVectorType();
+     if (targetType.getRank() <= 1)
+       return failure();
+
+     FailureOr<size_t> maybeDimsToDrop =
+         getTransferFoldableInnerUnitDims(srcType, targetType);
+     if (failed(maybeDimsToDrop))
+       return failure();
+
+     size_t dimsToDrop = maybeDimsToDrop.value();
+     if (dimsToDrop == 0)
+       return failure();
+
+     auto resultTargetVecType =
+         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                         targetType.getElementType());
+
+     MemRefType resultMemrefType =
+         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+     SmallVector<int64_t> offsets(srcType.getRank(), 0);
+     SmallVector<int64_t> strides(srcType.getRank(), 1);
+     ArrayAttr inBoundsAttr =
+         writeOp.getInBounds()
+             ? rewriter.getArrayAttr(
+                   writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+             : ArrayAttr();
+
+     Location loc = writeOp.getLoc();
+     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+         loc, resultMemrefType, writeOp.getSource(), offsets,
+         srcType.getShape(), strides);
+     auto permMap = getTransferMinorIdentityMap(
+         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+         loc, resultTargetVecType, writeOp.getVector());
+     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+         writeOp, shapeCast, rankedReducedView,
+         writeOp.getIndices().drop_back(dimsToDrop),
+         AffineMapAttr::get(permMap),
+         // TODO: support mask.
+         /*mask=*/Value(), inBoundsAttr);
+     return success();
+   }
+ };
+
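For example (an editorial sketch), a masked write is left untouched because of
the mask bail-out above:

    // Not rewritten: the pattern returns failure() when a mask is present.
    vector.transfer_write %v, %dest[%c0, %c0, %c0], %mask
        {in_bounds = [true, true, true]}
        : vector<4x1x1xf32>, memref<16x1x1xf32>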
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
@@ -1696,7 +1805,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+  patterns.add<DropInnerMostUnitDimsTransferRead,
+               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+                                                   benefit);
}

void mlir::vector::populateSinkVectorBroadcastPatterns(
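Downstream users pick both patterns up through the populate function above. In
MLIR's lit tests such patterns are typically exercised via a test pass; the
flag name below is an assumption based on the populate function's name, not
part of this diff:

    // RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims | FileCheck %s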