Skip to content

Commit 8dbb4ce

Browse files
committed
fixup! [mlir][Linalg] Refine how broadcast dims are treated
* Move the logic to remove zero from indexing maps to `maskOperation` * Update the input mask name in `maskOperation` to `maybeIndexingMap` - the actual input is always an indexing map extracted from the corresponding linalg Op * Remove the duplicated comment for `maskOperation`
1 parent 9c5c154 commit 8dbb4ce

File tree

1 file changed

+21
-32
lines changed

1 file changed

+21
-32
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,10 @@ struct VectorizationState {
224224
/// Masks an operation with the canonical vector mask if the operation needs
225225
/// masking. Returns the masked operation or the original operation if masking
226226
/// is not needed. If provided, the canonical mask for this operation is
227-
/// permuted using `maybeMaskingMap`.
227+
/// permuted using `maybeIndexingMap`.
228228
Operation *
229229
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230-
std::optional<AffineMap> maybeMaskingMap = std::nullopt);
230+
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231231

232232
private:
233233
/// Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
422422
return mask;
423423
}
424424

425-
/// Masks an operation with the canonical vector mask if the operation needs
426-
/// masking. Returns the masked operation or the original operation if masking
427-
/// is not needed. If provided, the canonical mask for this operation is
428-
/// permuted using `maybeMaskingMap`.
429425
Operation *
430426
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
431427
LinalgOp linalgOp,
432-
std::optional<AffineMap> maybeMaskingMap) {
428+
std::optional<AffineMap> maybeIndexingMap) {
433429
LDBG("Trying to mask: " << *opToMask << "\n");
434430

431+
std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432+
// The Operand indexing map may contain "zero" results, e.g.:
433+
// (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434+
// When applied to canonical vector shapes like these:
435+
// (1, 16, 16, 4)
436+
// we would get:
437+
// (1, 16, 16, 0)
438+
// Instead, we should extract the following map permutation map for masking:
439+
// (d0, d1, d2, d3) -> (d0, d1, d2)
440+
// This way, the corresponding vector/mask type will be:
441+
// vector<1x16x16xty>
442+
// rather than:
443+
// vector<1x16x16x0xty>
444+
if (maybeIndexingMap)
445+
maybeMaskingMap = maybeIndexingMap->dropZeroResults();
446+
435447
// Create or retrieve mask for this operation.
436448
Value mask =
437449
getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -630,21 +642,8 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
630642
loc, value, outputOperand->get(), ValueRange{});
631643
}
632644

633-
// The operand map may contain "zero" results, e.g.:
634-
// (d0, d1, d2, d3) -> (d0, d1, d2, 0)
635-
// When applied to canonical vector shapes like these:
636-
// (1, 16, 16, 4)
637-
// we would get:
638-
// (1, 16, 16, 0)
639-
// Instead, we should extract the following map:
640-
// (d0, d1, d2, d3) -> (d0, d1, d2)
641-
// This way, the corresponding vector/mask type will be:
642-
// vector<1x16x16xty>
643-
// rather than:
644-
// vector<1x16x16x0xty>
645-
AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults();
646645
write =
647-
state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
646+
state.maskOperation(rewriter, write, linalgOp, opOperandMap);
648647

649648
// If masked, set in-bounds to true. Masking guarantees that the access will
650649
// be in-bounds.
@@ -1330,16 +1329,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13301329
// permutation map and masking map.
13311330
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
13321331

1333-
// Remove zeros from indexing map to use it as masking map.
1334-
SmallVector<int64_t> zeroPos;
1335-
auto results = indexingMap.getResults();
1336-
for (const auto &result : llvm::enumerate(results)) {
1337-
if (isa<AffineConstantExpr>(result.value())) {
1338-
zeroPos.push_back(result.index());
1339-
}
1340-
}
1341-
AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1342-
13431332
AffineMap readMap;
13441333
VectorType readType;
13451334
Type elemType = getElementTypeOrSelf(opOperand->get());
@@ -1369,7 +1358,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13691358
Operation *read = rewriter.create<vector::TransferReadOp>(
13701359
loc, readType, opOperand->get(), indices, readMap,
13711360
ArrayRef<bool>(inBounds));
1372-
read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1361+
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
13731362
Value readValue = read->getResult(0);
13741363

13751364
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access

0 commit comments

Comments
 (0)