@@ -63,7 +63,6 @@ class TransferOptimization {
63
63
std::vector<Operation *> opToErase;
64
64
};
65
65
66
- } // namespace
67
66
// / Return true if there is a path from start operation to dest operation,
68
67
// / otherwise return false. The operations have to be in the same region.
69
68
bool TransferOptimization::isReachable (Operation *start, Operation *dest) {
@@ -289,25 +288,14 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
289
288
return llvm::count_if (shape, [](int64_t dimSize) { return dimSize != 1 ; });
290
289
}
291
290
292
- // / Returns a copy of `shape` without unit dims.
293
- static SmallVector<int64_t > getReducedShape (ArrayRef<int64_t > shape) {
294
- SmallVector<int64_t > reducedShape;
295
- llvm::copy_if (shape, std::back_inserter (reducedShape),
296
- [](int64_t dimSize) { return dimSize != 1 ; });
297
- return reducedShape;
298
- }
299
-
300
291
// / Returns true if all values are `arith.constant 0 : index`
301
292
static bool isZero (Value v) {
302
293
auto cst = v.getDefiningOp <arith::ConstantIndexOp>();
303
294
return cst && cst.value () == 0 ;
304
295
}
305
296
306
- namespace {
307
-
308
- // / Rewrites `vector.transfer_read` ops where the source has unit dims, by
309
- // / inserting a memref.subview dropping those unit dims. The vector shapes are
310
- // / also reduced accordingly.
297
+ // / Rewrites vector.transfer_read ops where the source has unit dims, by
298
+ // / inserting a memref.subview dropping those unit dims.
311
299
class TransferReadDropUnitDimsPattern
312
300
: public OpRewritePattern<vector::TransferReadOp> {
313
301
using OpRewritePattern::OpRewritePattern;
@@ -329,15 +317,12 @@ class TransferReadDropUnitDimsPattern
329
317
return failure ();
330
318
if (!transferReadOp.getPermutationMap ().isMinorIdentity ())
331
319
return failure ();
332
- // Check if the source shape can be further reduced.
333
320
int reducedRank = getReducedRank (sourceType.getShape ());
334
321
if (reducedRank == sourceType.getRank ())
335
- return failure ();
336
- // Check if the reduced vector shape matches the reduced source shape.
337
- // Otherwise, this case is not supported yet.
338
- int vectorReducedRank = getReducedRank (vectorType.getShape ());
339
- if (reducedRank != vectorReducedRank)
340
- return failure ();
322
+ return failure (); // The source shape can't be further reduced.
323
+ if (reducedRank != vectorType.getRank ())
324
+ return failure (); // This pattern requires the vector shape to match the
325
+ // reduced source shape.
341
326
if (llvm::any_of (transferReadOp.getIndices (),
342
327
[](Value v) { return !isZero (v); }))
343
328
return failure ();
@@ -346,22 +331,14 @@ class TransferReadDropUnitDimsPattern
346
331
Value c0 = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
347
332
SmallVector<Value> zeros (reducedRank, c0);
348
333
auto identityMap = rewriter.getMultiDimIdentityMap (reducedRank);
349
- auto reducedVectorType = VectorType::get (
350
- getReducedShape (vectorType.getShape ()), vectorType.getElementType ());
351
-
352
- auto newTransferReadOp = rewriter.create <vector::TransferReadOp>(
353
- loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
354
- auto shapeCast = rewriter.createOrFold <vector::ShapeCastOp>(
355
- loc, vectorType, newTransferReadOp);
356
- rewriter.replaceOp (transferReadOp, shapeCast);
357
-
334
+ rewriter.replaceOpWithNewOp <vector::TransferReadOp>(
335
+ transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
358
336
return success ();
359
337
}
360
338
};
361
339
362
- // / Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
363
- // / has unit dims, by inserting a `memref.subview` dropping those unit dims. The
364
- // / vector shapes are also reduced accordingly.
340
+ // / Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
341
+ // / unit dims, by inserting a memref.subview dropping those unit dims.
365
342
class TransferWriteDropUnitDimsPattern
366
343
: public OpRewritePattern<vector::TransferWriteOp> {
367
344
using OpRewritePattern::OpRewritePattern;
@@ -383,15 +360,12 @@ class TransferWriteDropUnitDimsPattern
383
360
return failure ();
384
361
if (!transferWriteOp.getPermutationMap ().isMinorIdentity ())
385
362
return failure ();
386
- // Check if the destination shape can be further reduced.
387
363
int reducedRank = getReducedRank (sourceType.getShape ());
388
364
if (reducedRank == sourceType.getRank ())
389
- return failure ();
390
- // Check if the reduced vector shape matches the reduced destination shape.
391
- // Otherwise, this case is not supported yet.
392
- int vectorReducedRank = getReducedRank (vectorType.getShape ());
393
- if (reducedRank != vectorReducedRank)
394
- return failure ();
365
+ return failure (); // The source shape can't be further reduced.
366
+ if (reducedRank != vectorType.getRank ())
367
+ return failure (); // This pattern requires the vector shape to match the
368
+ // reduced source shape.
395
369
if (llvm::any_of (transferWriteOp.getIndices (),
396
370
[](Value v) { return !isZero (v); }))
397
371
return failure ();
@@ -400,20 +374,12 @@ class TransferWriteDropUnitDimsPattern
400
374
Value c0 = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
401
375
SmallVector<Value> zeros (reducedRank, c0);
402
376
auto identityMap = rewriter.getMultiDimIdentityMap (reducedRank);
403
- VectorType reducedVectorType = VectorType::get (
404
- getReducedShape (vectorType.getShape ()), vectorType.getElementType ());
405
-
406
- auto shapeCast = rewriter.createOrFold <vector::ShapeCastOp>(
407
- loc, reducedVectorType, vector);
408
377
rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
409
- transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
410
-
378
+ transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
411
379
return success ();
412
380
}
413
381
};
414
382
415
- } // namespace
416
-
417
383
// / Return true if the memref type has its inner dimension matching the given
418
384
// / shape. Otherwise return false.
419
385
static int64_t hasMatchingInnerContigousShape (MemRefType memrefType,
@@ -473,8 +439,6 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
473
439
return success ();
474
440
}
475
441
476
- namespace {
477
-
478
442
// / Rewrites contiguous row-major vector.transfer_read ops by inserting
479
443
// / memref.collapse_shape on the source so that the resulting
480
444
// / vector.transfer_read has a 1D source. Requires the source shape to be
@@ -768,7 +732,6 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
768
732
return success ();
769
733
}
770
734
};
771
-
772
735
} // namespace
773
736
774
737
void mlir::vector::transferOpflowOpt (RewriterBase &rewriter,
0 commit comments