Skip to content

Commit f7f51f2

Browse files
authored
[mlir][vector] Clarify the semantics of masking maps (nfc) (#111383)
We use the term "masking map" throughout the Linalg vectorization logic, but we don't really define what it is and how it differs from Linalg indexing maps. This PR clarifies the differnces, makes sure that the new terminology is used consistenty and improves code re-use.
1 parent 9d0616c commit f7f51f2

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,32 @@ struct VectorizationState {
250250
LinalgOp linalgOp,
251251
std::optional<AffineMap> maybeMaskingMap);
252252

253+
/// Check whether this permutation map can be used for masking. At the
254+
/// moment we only make sure that there are no broadcast dimensions, but this
255+
/// might change if indexing maps evolve.
256+
bool isValidMaskingMap(AffineMap maskingMap) {
257+
return maskingMap.getBroadcastDims().size() == 0;
258+
}
259+
260+
/// Turn the input indexing map into a valid masking map.
261+
///
262+
/// The input indexing map may contain "zero" results, e.g.:
263+
/// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264+
/// Applying such maps to canonical vector shapes like this one:
265+
/// (1, 16, 16, 4)
266+
/// would yield an invalid vector shape like this:
267+
/// (16, 16, 1, 0)
268+
/// Instead, drop the broadcasting dims that make no sense for masking perm.
269+
/// maps:
270+
/// (d0, d1, d2, d3) -> (d2, d1, d0)
271+
/// This way, the corresponding vector/mask type will be:
272+
/// vector<16x16x1xty>
273+
/// rather than this invalid Vector type:
274+
/// vector<16x16x1x0xty>
275+
AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
276+
return indexingMap.dropZeroResults();
277+
}
278+
253279
// Holds the compile-time static sizes of the iteration space to vectorize.
254280
// Dynamic dimensions are represented using ShapedType::kDynamic.
255281
SmallVector<int64_t> iterSpaceStaticSizes;
@@ -360,6 +386,10 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
360386
Value VectorizationState::getOrCreateMaskFor(
361387
RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362388
std::optional<AffineMap> maybeMaskingMap) {
389+
390+
assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
391+
"Ill-formed masking map.");
392+
363393
// No mask is needed if the operation is not maskable.
364394
auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365395
if (!maskableOp)
@@ -429,20 +459,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
429459
LDBG("Trying to mask: " << *opToMask << "\n");
430460

431461
std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432-
// The Operand indexing map may contain "zero" results, e.g.:
433-
// (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434-
// When applied to canonical vector shapes like these:
435-
// (1, 16, 16, 4)
436-
// we would get:
437-
// (1, 16, 16, 0)
438-
// Instead, we should extract the following map permutation map for masking:
439-
// (d0, d1, d2, d3) -> (d0, d1, d2)
440-
// This way, the corresponding vector/mask type will be:
441-
// vector<1x16x16xty>
442-
// rather than:
443-
// vector<1x16x16x0xty>
444462
if (maybeIndexingMap)
445-
maybeMaskingMap = maybeIndexingMap->dropZeroResults();
463+
maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
446464

447465
// Create or retrieve mask for this operation.
448466
Value mask =

0 commit comments

Comments
 (0)