@@ -385,17 +385,25 @@ struct OneDimMultiReductionToTwoDim
385
385
386
386
LogicalResult matchAndRewrite (vector::MultiDimReductionOp multiReductionOp,
387
387
PatternRewriter &rewriter) const override {
388
- auto maskableOp =
389
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation ());
390
- if (maskableOp.isMasked ())
391
- // TODO: Support masking.
392
- return failure ();
393
-
394
388
auto srcRank = multiReductionOp.getSourceVectorType ().getRank ();
395
389
// Rank-1 or bail.
396
390
if (srcRank != 1 )
397
391
return failure ();
398
392
393
+ // Vector mask setup.
394
+ OpBuilder::InsertionGuard guard (rewriter);
395
+ auto maskableOp =
396
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation ());
397
+ Operation *rootOp;
398
+ Value mask;
399
+ if (maskableOp.isMasked ()) {
400
+ rewriter.setInsertionPoint (maskableOp.getMaskingOp ());
401
+ rootOp = maskableOp.getMaskingOp ();
402
+ mask = maskableOp.getMaskingOp ().getMask ();
403
+ } else {
404
+ rootOp = multiReductionOp;
405
+ }
406
+
399
407
auto loc = multiReductionOp.getLoc ();
400
408
auto srcVectorType = multiReductionOp.getSourceVectorType ();
401
409
auto srcShape = srcVectorType.getShape ();
@@ -408,16 +416,27 @@ struct OneDimMultiReductionToTwoDim
408
416
409
417
// If the unique dim is reduced and we insert a parallel in front, we need a
410
418
// {false, true} mask.
411
- SmallVector<bool , 2 > mask {false , true };
419
+ SmallVector<bool , 2 > reductionMask {false , true };
412
420
413
421
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
414
422
Value cast = rewriter.create <vector::ShapeCastOp>(
415
423
loc, castedType, multiReductionOp.getSource ());
416
424
Value castAcc = rewriter.create <vector::BroadcastOp>(
417
425
loc, accType, multiReductionOp.getAcc ());
418
- Value reduced = rewriter.create <vector::MultiDimReductionOp>(
419
- loc, cast, castAcc, mask, multiReductionOp.getKind ());
420
- rewriter.replaceOpWithNewOp <vector::ExtractOp>(multiReductionOp, reduced,
426
+ Value castMask;
427
+ if (maskableOp.isMasked ()) {
428
+ auto maskType = mask.getType ().cast <ShapedType>();
429
+ auto castMaskType =
430
+ VectorType::get (ArrayRef<int64_t >{1 , maskType.getShape ().back ()},
431
+ maskType.getElementType ());
432
+ castMask = rewriter.create <vector::BroadcastOp>(loc, castMaskType, mask);
433
+ }
434
+
435
+ Operation *newOp = rewriter.create <vector::MultiDimReductionOp>(
436
+ loc, cast, castAcc, reductionMask, multiReductionOp.getKind ());
437
+ newOp = vector::maskOperation (rewriter, newOp, castMask);
438
+
439
+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(rootOp, newOp->getResult (0 ),
421
440
ArrayRef<int64_t >{0 });
422
441
return success ();
423
442
}
0 commit comments