Skip to content

Commit 7db6c3a

Browse files
committed
fixup! [mlir][Linalg] Refine how broadcast dims are treated
* Move the logic to remove zero from indexing maps to `maskOperation` * Update the input mask name in `maskOperation` to `maybeIndexingMap` - the actual input is always an indexing map extracted from the corresponding linalg Op * Remove the duplicated comment for `maskOperation`
1 parent c726ac3 commit 7db6c3a

File tree

1 file changed

+21
-33
lines changed

1 file changed

+21
-33
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,10 @@ struct VectorizationState {
224224
/// Masks an operation with the canonical vector mask if the operation needs
225225
/// masking. Returns the masked operation or the original operation if masking
226226
/// is not needed. If provided, the canonical mask for this operation is
227-
/// permuted using `maybeMaskingMap`.
227+
/// permuted using `maybeIndexingMap`.
228228
Operation *
229229
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
230-
std::optional<AffineMap> maybeMaskingMap = std::nullopt);
230+
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
231231

232232
private:
233233
/// Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
422422
return mask;
423423
}
424424

425-
/// Masks an operation with the canonical vector mask if the operation needs
426-
/// masking. Returns the masked operation or the original operation if masking
427-
/// is not needed. If provided, the canonical mask for this operation is
428-
/// permuted using `maybeMaskingMap`.
429425
Operation *
430426
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
431427
LinalgOp linalgOp,
432-
std::optional<AffineMap> maybeMaskingMap) {
428+
std::optional<AffineMap> maybeIndexingMap) {
433429
LDBG("Trying to mask: " << *opToMask << "\n");
434430

431+
std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432+
// The Operand indexing map may contain "zero" results, e.g.:
433+
// (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434+
// When applied to canonical vector shapes like these:
435+
// (1, 16, 16, 4)
436+
// we would get:
437+
// (1, 16, 16, 0)
438+
// Instead, we should extract the following map permutation map for masking:
439+
// (d0, d1, d2, d3) -> (d0, d1, d2)
440+
// This way, the corresponding vector/mask type will be:
441+
// vector<1x16x16xty>
442+
// rather than:
443+
// vector<1x16x16x0xty>
444+
if (maybeIndexingMap)
445+
maybeMaskingMap = maybeIndexingMap->dropZeroResults();
446+
435447
// Create or retrieve mask for this operation.
436448
Value mask =
437449
getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -640,21 +652,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
640652
loc, value, outputOperand->get(), ValueRange{});
641653
}
642654

643-
// The operand map may contain "zero" results, e.g.:
644-
// (d0, d1, d2, d3) -> (d0, d1, d2, 0)
645-
// When applied to canonical vector shapes like these:
646-
// (1, 16, 16, 4)
647-
// we would get:
648-
// (1, 16, 16, 0)
649-
// Instead, we should extract the following map:
650-
// (d0, d1, d2, d3) -> (d0, d1, d2)
651-
// This way, the corresponding vector/mask type will be:
652-
// vector<1x16x16xty>
653-
// rather than:
654-
// vector<1x16x16x0xty>
655-
AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults();
656-
write =
657-
state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
655+
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
658656

659657
// If masked, set in-bounds to true. Masking guarantees that the access will
660658
// be in-bounds.
@@ -1332,16 +1330,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13321330
// permutation map and masking map.
13331331
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
13341332

1335-
// Remove zeros from indexing map to use it as masking map.
1336-
SmallVector<int64_t> zeroPos;
1337-
auto results = indexingMap.getResults();
1338-
for (const auto &result : llvm::enumerate(results)) {
1339-
if (isa<AffineConstantExpr>(result.value())) {
1340-
zeroPos.push_back(result.index());
1341-
}
1342-
}
1343-
AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1344-
13451333
AffineMap readMap;
13461334
VectorType readType;
13471335
Type elemType = getElementTypeOrSelf(opOperand->get());
@@ -1371,7 +1359,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
13711359
Operation *read = rewriter.create<vector::TransferReadOp>(
13721360
loc, readType, opOperand->get(), indices, readMap,
13731361
ArrayRef<bool>(inBounds));
1374-
read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1362+
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
13751363
Value readValue = read->getResult(0);
13761364

13771365
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access

0 commit comments

Comments
 (0)