@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
260
260
opToErase.push_back (read.getOperation ());
261
261
}
262
262
263
- /// Returns a copy of `shape` without unit dims.
264
- static SmallVector<int64_t > getReducedShape (ArrayRef<int64_t > shape) {
265
- SmallVector<int64_t > reducedShape;
266
- llvm::copy_if (shape, std::back_inserter (reducedShape),
267
- [](int64_t dimSize) { return dimSize != 1 ; });
268
- return reducedShape;
269
- }
270
-
271
263
/// Converts OpFoldResults to int64_t shape without unit dims.
272
264
static SmallVector<int64_t > getReducedShape (ArrayRef<OpFoldResult> mixedSizes) {
273
265
SmallVector<int64_t > reducedShape;
@@ -340,7 +332,7 @@ static FailureOr<Value>
340
332
createMaskDropNonScalableUnitDims (PatternRewriter &rewriter, Location loc,
341
333
vector::CreateMaskOp op) {
342
334
auto type = op.getType ();
343
- auto reducedType = trimNonScalableUnitDims (type);
335
+ VectorType reducedType = trimNonScalableUnitDims (type);
344
336
if (reducedType.getRank () == type.getRank ())
345
337
return failure ();
346
338
@@ -391,7 +383,7 @@ class TransferReadDropUnitDimsPattern
391
383
return failure ();
392
384
// Check if the reduced vector shape matches the reduced source shape.
393
385
// Otherwise, this case is not supported yet.
394
- auto reducedVectorType = trimNonScalableUnitDims (vectorType);
386
+ VectorType reducedVectorType = trimNonScalableUnitDims (vectorType);
395
387
if (reducedRank != reducedVectorType.getRank ())
396
388
return failure ();
397
389
if (llvm::any_of (transferReadOp.getIndices (), [](Value v) {
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
446
438
Value source = transferWriteOp.getSource ();
447
439
MemRefType sourceType = dyn_cast<MemRefType>(source.getType ());
448
440
// TODO: support tensor type.
449
- if (!sourceType || !sourceType.hasStaticShape ())
450
- return failure ();
451
- if (sourceType.getNumElements () != vectorType.getNumElements ())
441
+ if (!sourceType)
452
442
return failure ();
453
443
// TODO: generalize this pattern, relax the requirements here.
454
444
if (transferWriteOp.hasOutOfBoundsDim ())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
461
451
return failure ();
462
452
// Check if the reduced vector shape matches the reduced destination shape.
463
453
// Otherwise, this case is not supported yet.
464
- int vectorReducedRank = getReducedRank (vectorType. getShape () );
465
- if (reducedRank != vectorReducedRank )
454
+ VectorType reducedVectorType = trimNonScalableUnitDims (vectorType);
455
+ if (reducedRank != reducedVectorType. getRank () )
466
456
return failure ();
467
457
if (llvm::any_of (transferWriteOp.getIndices (), [](Value v) {
468
458
return getConstantIntValue (v) != static_cast <int64_t >(0 );
469
459
}))
470
460
return failure ();
461
+
462
+ Value maskOp = transferWriteOp.getMask ();
463
+ if (maskOp) {
464
+ auto createMaskOp = maskOp.getDefiningOp <vector::CreateMaskOp>();
465
+ if (!createMaskOp)
466
+ return rewriter.notifyMatchFailure (
467
+ transferWriteOp,
468
+ " unsupported mask op, only 'vector.create_mask' is "
469
+ " currently supported" );
470
+ FailureOr<Value> rankReducedCreateMask =
471
+ createMaskDropNonScalableUnitDims (rewriter, loc, createMaskOp);
472
+ if (failed (rankReducedCreateMask))
473
+ return failure ();
474
+ maskOp = *rankReducedCreateMask;
475
+ }
471
476
Value reducedShapeSource =
472
477
rankReducingSubviewDroppingUnitDims (rewriter, loc, source);
473
478
Value c0 = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
474
479
SmallVector<Value> zeros (reducedRank, c0);
475
480
auto identityMap = rewriter.getMultiDimIdentityMap (reducedRank);
476
- VectorType reducedVectorType = VectorType::get (
477
- getReducedShape (vectorType.getShape ()), vectorType.getElementType ());
478
-
481
+ SmallVector<bool > inBounds (reducedVectorType.getRank (), true );
479
482
auto shapeCast = rewriter.createOrFold <vector::ShapeCastOp>(
480
483
loc, reducedVectorType, vector);
481
484
rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
482
- transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
485
+ transferWriteOp, Type (), shapeCast, reducedShapeSource, zeros,
486
+ identityMap, maskOp, rewriter.getBoolArrayAttr (inBounds));
483
487
484
488
return success ();
485
489
}
0 commit comments