@@ -224,10 +224,10 @@ struct VectorizationState {
224
224
// / Masks an operation with the canonical vector mask if the operation needs
225
225
// / masking. Returns the masked operation or the original operation if masking
226
226
// / is not needed. If provided, the canonical mask for this operation is
227
- // / permuted using `maybeMaskingMap `.
227
+ // / permuted using `maybeIndexingMap `.
228
228
Operation *
229
229
maskOperation (RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230
- std::optional<AffineMap> maybeMaskingMap = std::nullopt);
230
+ std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231
231
232
232
private:
233
233
// / Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
422
422
return mask;
423
423
}
424
424
425
- // / Masks an operation with the canonical vector mask if the operation needs
426
- // / masking. Returns the masked operation or the original operation if masking
427
- // / is not needed. If provided, the canonical mask for this operation is
428
- // / permuted using `maybeMaskingMap`.
429
425
Operation *
430
426
VectorizationState::maskOperation (RewriterBase &rewriter, Operation *opToMask,
431
427
LinalgOp linalgOp,
432
- std::optional<AffineMap> maybeMaskingMap ) {
428
+ std::optional<AffineMap> maybeIndexingMap ) {
433
429
LDBG (" Trying to mask: " << *opToMask << " \n " );
434
430
431
+ std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432
+ // The Operand indexing map may contain "zero" results, e.g.:
433
+ // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434
+ // When applied to canonical vector shapes like these:
435
+ // (1, 16, 16, 4)
436
+ // we would get:
437
+ // (1, 16, 16, 0)
438
+ // Instead, we should extract the following map permutation map for masking:
439
+ // (d0, d1, d2, d3) -> (d0, d1, d2)
440
+ // This way, the corresponding vector/mask type will be:
441
+ // vector<1x16x16xty>
442
+ // rather than:
443
+ // vector<1x16x16x0xty>
444
+ if (maybeIndexingMap)
445
+ maybeMaskingMap = maybeIndexingMap->dropZeroResults ();
446
+
435
447
// Create or retrieve mask for this operation.
436
448
Value mask =
437
449
getOrCreateMaskFor (rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -640,21 +652,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
640
652
loc, value, outputOperand->get (), ValueRange{});
641
653
}
642
654
643
- // The operand map may contain "zero" results, e.g.:
644
- // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
645
- // When applied to canonical vector shapes like these:
646
- // (1, 16, 16, 4)
647
- // we would get:
648
- // (1, 16, 16, 0)
649
- // Instead, we should extract the following map:
650
- // (d0, d1, d2, d3) -> (d0, d1, d2)
651
- // This way, the corresponding vector/mask type will be:
652
- // vector<1x16x16xty>
653
- // rather than:
654
- // vector<1x16x16x0xty>
655
- AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults ();
656
- write =
657
- state.maskOperation (rewriter, write, linalgOp, opOperantMapWithoutZeros);
655
+ write = state.maskOperation (rewriter, write, linalgOp, opOperandMap);
658
656
659
657
// If masked, set in-bounds to true. Masking guarantees that the access will
660
658
// be in-bounds.
@@ -1332,16 +1330,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1332
1330
// permutation map and masking map.
1333
1331
AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
1334
1332
1335
- // Remove zeros from indexing map to use it as masking map.
1336
- SmallVector<int64_t > zeroPos;
1337
- auto results = indexingMap.getResults ();
1338
- for (const auto &result : llvm::enumerate (results)) {
1339
- if (isa<AffineConstantExpr>(result.value ())) {
1340
- zeroPos.push_back (result.index ());
1341
- }
1342
- }
1343
- AffineMap maskingMap = indexingMap.dropResults (zeroPos);
1344
-
1345
1333
AffineMap readMap;
1346
1334
VectorType readType;
1347
1335
Type elemType = getElementTypeOrSelf (opOperand->get ());
@@ -1371,7 +1359,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1371
1359
Operation *read = rewriter.create <vector::TransferReadOp>(
1372
1360
loc, readType, opOperand->get (), indices, readMap,
1373
1361
ArrayRef<bool >(inBounds));
1374
- read = state.maskOperation (rewriter, read, linalgOp, maskingMap );
1362
+ read = state.maskOperation (rewriter, read, linalgOp, indexingMap );
1375
1363
Value readValue = read->getResult (0 );
1376
1364
1377
1365
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
0 commit comments