@@ -250,6 +250,32 @@ struct VectorizationState {
250
250
LinalgOp linalgOp,
251
251
std::optional<AffineMap> maybeMaskingMap);
252
252
253
+ // / Check whether this permutation map can be used for masking. At the
254
+ // / moment we only make sure that there are no broadcast dimensions, but this
255
+ // / might change if indexing maps evolve.
256
+ bool isValidMaskingMap (AffineMap maskingMap) {
257
+ return maskingMap.getBroadcastDims ().size () == 0 ;
258
+ }
259
+
260
+ // / Turn the input indexing map into a valid masking map.
261
+ // /
262
+ // / The input indexing map may contain "zero" results, e.g.:
263
+ // / (d0, d1, d2, d3) -> (d2, d1, d0, 0)
264
+ // / Applying such maps to canonical vector shapes like this one:
265
+ // / (1, 16, 16, 4)
266
+ // / would yield an invalid vector shape like this:
267
+ // / (16, 16, 1, 0)
268
+ // / Instead, drop the broadcasting dims that make no sense for masking perm.
269
+ // / maps:
270
+ // / (d0, d1, d2, d3) -> (d2, d1, d0)
271
+ // / This way, the corresponding vector/mask type will be:
272
+ // / vector<16x16x1xty>
273
+ // / rather than this invalid Vector type:
274
+ // / vector<16x16x1x0xty>
275
+ AffineMap getMaskingMapFromIndexingMap (AffineMap &indexingMap) {
276
+ return indexingMap.dropZeroResults ();
277
+ }
278
+
253
279
// Holds the compile-time static sizes of the iteration space to vectorize.
254
280
// Dynamic dimensions are represented using ShapedType::kDynamic.
255
281
SmallVector<int64_t > iterSpaceStaticSizes;
@@ -360,6 +386,10 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
360
386
Value VectorizationState::getOrCreateMaskFor (
361
387
RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
362
388
std::optional<AffineMap> maybeMaskingMap) {
389
+
390
+ assert ((!maybeMaskingMap || isValidMaskingMap (*maybeMaskingMap)) &&
391
+ " Ill-formed masking map." );
392
+
363
393
// No mask is needed if the operation is not maskable.
364
394
auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
365
395
if (!maskableOp)
@@ -429,20 +459,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
429
459
LDBG (" Trying to mask: " << *opToMask << " \n " );
430
460
431
461
std::optional<AffineMap> maybeMaskingMap = std::nullopt;
432
- // The Operand indexing map may contain "zero" results, e.g.:
433
- // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
434
- // When applied to canonical vector shapes like these:
435
- // (1, 16, 16, 4)
436
- // we would get:
437
- // (1, 16, 16, 0)
438
- // Instead, we should extract the following map permutation map for masking:
439
- // (d0, d1, d2, d3) -> (d0, d1, d2)
440
- // This way, the corresponding vector/mask type will be:
441
- // vector<1x16x16xty>
442
- // rather than:
443
- // vector<1x16x16x0xty>
444
462
if (maybeIndexingMap)
445
- maybeMaskingMap = maybeIndexingMap-> dropZeroResults ( );
463
+ maybeMaskingMap = getMaskingMapFromIndexingMap (*maybeIndexingMap );
446
464
447
465
// Create or retrieve mask for this operation.
448
466
Value mask =
0 commit comments