@@ -224,10 +224,10 @@ struct VectorizationState {
224
224
// / Masks an operation with the canonical vector mask if the operation needs
225
225
// / masking. Returns the masked operation or the original operation if masking
226
226
// / is not needed. If provided, the canonical mask for this operation is
227
- // / permuted using `maybeMaskingMap `.
227
+ // / permuted using `maybeIndexingMap `.
228
228
Operation *
229
229
maskOperation (RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230
- std::optional<AffineMap> maybeMaskingMap = std::nullopt);
230
+ std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231
231
232
232
private:
233
233
// / Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
422
422
return mask;
423
423
}
424
424
425
- // / Masks an operation with the canonical vector mask if the operation needs
426
- // / masking. Returns the masked operation or the original operation if masking
427
- // / is not needed. If provided, the canonical mask for this operation is
428
- // / permuted using `maybeMaskingMap`.
429
425
Operation *
430
426
VectorizationState::maskOperation (RewriterBase &rewriter, Operation *opToMask,
431
427
LinalgOp linalgOp,
432
- std::optional<AffineMap> maybeMaskingMap ) {
428
+ std::optional<AffineMap> maybeIndexingMap ) {
433
429
LDBG (" Trying to mask: " << *opToMask << " \n " );
434
430
431
+ std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432
+ // The Operand indexing map may contain "zero" results, e.g.:
433
+ // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434
+ // When applied to canonical vector shapes like these:
435
+ // (1, 16, 16, 4)
436
+ // we would get:
437
+ // (1, 16, 16, 0)
438
+ // Instead, we should extract the following map permutation map for masking:
439
+ // (d0, d1, d2, d3) -> (d0, d1, d2)
440
+ // This way, the corresponding vector/mask type will be:
441
+ // vector<1x16x16xty>
442
+ // rather than:
443
+ // vector<1x16x16x0xty>
444
+ if (maybeIndexingMap)
445
+ maybeMaskingMap = maybeIndexingMap->dropZeroResults ();
446
+
435
447
// Create or retrieve mask for this operation.
436
448
Value mask =
437
449
getOrCreateMaskFor (rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -630,21 +642,8 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
630
642
loc, value, outputOperand->get (), ValueRange{});
631
643
}
632
644
633
- // The operand map may contain "zero" results, e.g.:
634
- // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
635
- // When applied to canonical vector shapes like these:
636
- // (1, 16, 16, 4)
637
- // we would get:
638
- // (1, 16, 16, 0)
639
- // Instead, we should extract the following map:
640
- // (d0, d1, d2, d3) -> (d0, d1, d2)
641
- // This way, the corresponding vector/mask type will be:
642
- // vector<1x16x16xty>
643
- // rather than:
644
- // vector<1x16x16x0xty>
645
- AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults ();
646
645
write =
647
- state.maskOperation (rewriter, write, linalgOp, opOperantMapWithoutZeros );
646
+ state.maskOperation (rewriter, write, linalgOp, opOperandMap );
648
647
649
648
// If masked, set in-bounds to true. Masking guarantees that the access will
650
649
// be in-bounds.
@@ -1330,16 +1329,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1330
1329
// permutation map and masking map.
1331
1330
AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
1332
1331
1333
- // Remove zeros from indexing map to use it as masking map.
1334
- SmallVector<int64_t > zeroPos;
1335
- auto results = indexingMap.getResults ();
1336
- for (const auto &result : llvm::enumerate (results)) {
1337
- if (isa<AffineConstantExpr>(result.value ())) {
1338
- zeroPos.push_back (result.index ());
1339
- }
1340
- }
1341
- AffineMap maskingMap = indexingMap.dropResults (zeroPos);
1342
-
1343
1332
AffineMap readMap;
1344
1333
VectorType readType;
1345
1334
Type elemType = getElementTypeOrSelf (opOperand->get ());
@@ -1369,7 +1358,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1369
1358
Operation *read = rewriter.create <vector::TransferReadOp>(
1370
1359
loc, readType, opOperand->get (), indices, readMap,
1371
1360
ArrayRef<bool >(inBounds));
1372
- read = state.maskOperation (rewriter, read, linalgOp, maskingMap );
1361
+ read = state.maskOperation (rewriter, read, linalgOp, indexingMap );
1373
1362
Value readValue = read->getResult (0 );
1374
1363
1375
1364
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
0 commit comments