@@ -308,12 +308,6 @@ struct TwoDimMultiReductionToElementWise
308
308
309
309
LogicalResult matchAndRewrite (vector::MultiDimReductionOp multiReductionOp,
310
310
PatternRewriter &rewriter) const override {
311
- auto maskableOp =
312
- cast<vector::MaskableOpInterface>(multiReductionOp.getOperation ());
313
- if (maskableOp.isMasked ())
314
- // TODO: Support masking.
315
- return failure ();
316
-
317
311
auto srcRank = multiReductionOp.getSourceVectorType ().getRank ();
318
312
// Rank-2 ["parallel", "reduce"] or bail.
319
313
if (srcRank != 2 )
@@ -330,15 +324,33 @@ struct TwoDimMultiReductionToElementWise
330
324
if (!elementType.isIntOrIndexOrFloat ())
331
325
return failure ();
332
326
327
+ OpBuilder::InsertionGuard guard (rewriter);
328
+ auto maskableOp =
329
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation ());
330
+ Operation *rootOp;
331
+ Value mask = nullptr ;
332
+ if (maskableOp.isMasked ()) {
333
+ rewriter.setInsertionPoint (maskableOp.getMaskingOp ());
334
+ rootOp = maskableOp.getMaskingOp ();
335
+ mask = maskableOp.getMaskingOp ().getMask ();
336
+ } else {
337
+ rootOp = multiReductionOp;
338
+ }
339
+
333
340
Value result = multiReductionOp.getAcc ();
334
341
for (int64_t i = 0 ; i < srcShape[0 ]; i++) {
335
342
auto operand = rewriter.create <vector::ExtractOp>(
336
343
loc, multiReductionOp.getSource (), i);
337
- result = makeArithReduction (rewriter, loc, multiReductionOp.getKind (),
338
- operand, result);
344
+ Value extractMask = nullptr ;
345
+ if (mask) {
346
+ extractMask = rewriter.create <vector::ExtractOp>(loc, mask, i);
347
+ }
348
+ result =
349
+ makeArithReduction (rewriter, loc, multiReductionOp.getKind (), operand,
350
+ result, /* fastmath=*/ nullptr , extractMask);
339
351
}
340
352
341
- rewriter.replaceOp (multiReductionOp , result);
353
+ rewriter.replaceOp (rootOp , result);
342
354
return success ();
343
355
}
344
356
};
0 commit comments