@@ -63,6 +63,7 @@ class TransferOptimization {
63
63
std::vector<Operation *> opToErase;
64
64
};
65
65
66
+ } // namespace
66
67
// / Return true if there is a path from start operation to dest operation,
67
68
// / otherwise return false. The operations have to be in the same region.
68
69
bool TransferOptimization::isReachable (Operation *start, Operation *dest) {
@@ -288,14 +289,25 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
288
289
return llvm::count_if (shape, [](int64_t dimSize) { return dimSize != 1 ; });
289
290
}
290
291
292
+ // / Returns a copy of `shape` without unit dims.
293
+ static SmallVector<int64_t > getReducedShape (ArrayRef<int64_t > shape) {
294
+ SmallVector<int64_t > reducedShape;
295
+ llvm::copy_if (shape, std::back_inserter (reducedShape),
296
+ [](int64_t dimSize) { return dimSize != 1 ; });
297
+ return reducedShape;
298
+ }
299
+
291
300
// / Returns true if all values are `arith.constant 0 : index`
292
301
static bool isZero (Value v) {
293
302
auto cst = v.getDefiningOp <arith::ConstantIndexOp>();
294
303
return cst && cst.value () == 0 ;
295
304
}
296
305
297
- // / Rewrites vector.transfer_read ops where the source has unit dims, by
298
- // / inserting a memref.subview dropping those unit dims.
306
+ namespace {
307
+
308
+ // / Rewrites `vector.transfer_read` ops where the source has unit dims, by
309
+ // / inserting a memref.subview dropping those unit dims. The vector shapes are
310
+ // / also reduced accordingly.
299
311
class TransferReadDropUnitDimsPattern
300
312
: public OpRewritePattern<vector::TransferReadOp> {
301
313
using OpRewritePattern::OpRewritePattern;
@@ -317,12 +329,15 @@ class TransferReadDropUnitDimsPattern
317
329
return failure ();
318
330
if (!transferReadOp.getPermutationMap ().isMinorIdentity ())
319
331
return failure ();
332
+ // Check if the source shape can be further reduced.
320
333
int reducedRank = getReducedRank (sourceType.getShape ());
321
334
if (reducedRank == sourceType.getRank ())
322
- return failure (); // The source shape can't be further reduced.
323
- if (reducedRank != vectorType.getRank ())
324
- return failure (); // This pattern requires the vector shape to match the
325
- // reduced source shape.
335
+ return failure ();
336
+ // Check if the reduced vector shape matches the reduced source shape.
337
+ // Otherwise, this case is not supported yet.
338
+ int vectorReducedRank = getReducedRank (vectorType.getShape ());
339
+ if (reducedRank != vectorReducedRank)
340
+ return failure ();
326
341
if (llvm::any_of (transferReadOp.getIndices (),
327
342
[](Value v) { return !isZero (v); }))
328
343
return failure ();
@@ -331,14 +346,22 @@ class TransferReadDropUnitDimsPattern
331
346
Value c0 = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
332
347
SmallVector<Value> zeros (reducedRank, c0);
333
348
auto identityMap = rewriter.getMultiDimIdentityMap (reducedRank);
334
- rewriter.replaceOpWithNewOp <vector::TransferReadOp>(
335
- transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
349
+ auto reducedVectorType = VectorType::get (
350
+ getReducedShape (vectorType.getShape ()), vectorType.getElementType ());
351
+
352
+ auto newTransferReadOp = rewriter.create <vector::TransferReadOp>(
353
+ loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
354
+ auto shapeCast = rewriter.createOrFold <vector::ShapeCastOp>(
355
+ loc, vectorType, newTransferReadOp);
356
+ rewriter.replaceOp (transferReadOp, shapeCast);
357
+
336
358
return success ();
337
359
}
338
360
};
339
361
340
- // / Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
341
- // / unit dims, by inserting a memref.subview dropping those unit dims.
362
+ // / Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
363
+ // / has unit dims, by inserting a `memref.subview` dropping those unit dims. The
364
+ // / vector shapes are also reduced accordingly.
342
365
class TransferWriteDropUnitDimsPattern
343
366
: public OpRewritePattern<vector::TransferWriteOp> {
344
367
using OpRewritePattern::OpRewritePattern;
@@ -360,12 +383,15 @@ class TransferWriteDropUnitDimsPattern
360
383
return failure ();
361
384
if (!transferWriteOp.getPermutationMap ().isMinorIdentity ())
362
385
return failure ();
386
+ // Check if the destination shape can be further reduced.
363
387
int reducedRank = getReducedRank (sourceType.getShape ());
364
388
if (reducedRank == sourceType.getRank ())
365
- return failure (); // The source shape can't be further reduced.
366
- if (reducedRank != vectorType.getRank ())
367
- return failure (); // This pattern requires the vector shape to match the
368
- // reduced source shape.
389
+ return failure ();
390
+ // Check if the reduced vector shape matches the reduced destination shape.
391
+ // Otherwise, this case is not supported yet.
392
+ int vectorReducedRank = getReducedRank (vectorType.getShape ());
393
+ if (reducedRank != vectorReducedRank)
394
+ return failure ();
369
395
if (llvm::any_of (transferWriteOp.getIndices (),
370
396
[](Value v) { return !isZero (v); }))
371
397
return failure ();
@@ -374,12 +400,20 @@ class TransferWriteDropUnitDimsPattern
374
400
Value c0 = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
375
401
SmallVector<Value> zeros (reducedRank, c0);
376
402
auto identityMap = rewriter.getMultiDimIdentityMap (reducedRank);
403
+ VectorType reducedVectorType = VectorType::get (
404
+ getReducedShape (vectorType.getShape ()), vectorType.getElementType ());
405
+
406
+ auto shapeCast = rewriter.createOrFold <vector::ShapeCastOp>(
407
+ loc, reducedVectorType, vector);
377
408
rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
378
- transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
409
+ transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
410
+
379
411
return success ();
380
412
}
381
413
};
382
414
415
+ } // namespace
416
+
383
417
// / Return true if the memref type has its inner dimension matching the given
384
418
// / shape. Otherwise return false.
385
419
static int64_t hasMatchingInnerContigousShape (MemRefType memrefType,
@@ -439,6 +473,8 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
439
473
return success ();
440
474
}
441
475
476
+ namespace {
477
+
442
478
// / Rewrites contiguous row-major vector.transfer_read ops by inserting
443
479
// / memref.collapse_shape on the source so that the resulting
444
480
// / vector.transfer_read has a 1D source. Requires the source shape to be
@@ -732,6 +768,7 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
732
768
return success ();
733
769
}
734
770
};
771
+
735
772
} // namespace
736
773
737
774
void mlir::vector::transferOpflowOpt (RewriterBase &rewriter,
0 commit comments