Skip to content

Commit ee5193c

Browse files
committed
[MLIR][Vector] Add support for inner-parallel masked multi-reductions
This commit adds suppot to lower inner-parallel flavor of masked vector multi-reductions.
1 parent e258bca commit ee5193c

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,16 @@ struct TwoDimMultiReductionToElementWise
310310
PatternRewriter &rewriter) const override {
311311
auto maskableOp =
312312
cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
313-
if (maskableOp.isMasked())
314-
// TODO: Support masking.
315-
return failure();
313+
314+
Operation *rootOp;
315+
Value mask = nullptr;
316+
if (maskableOp.isMasked()) {
317+
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
318+
rootOp = maskableOp.getMaskingOp();
319+
mask = maskableOp.getMaskingOp().getMask();
320+
} else {
321+
rootOp = multiReductionOp;
322+
}
316323

317324
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
318325
// Rank-2 ["parallel", "reduce"] or bail.
@@ -334,11 +341,15 @@ struct TwoDimMultiReductionToElementWise
334341
for (int64_t i = 0; i < srcShape[0]; i++) {
335342
auto operand = rewriter.create<vector::ExtractOp>(
336343
loc, multiReductionOp.getSource(), i);
344+
Value extractMask = nullptr;
345+
if (mask) {
346+
extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
347+
}
337348
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
338-
operand, result);
349+
operand, result, nullptr, extractMask);
339350
}
340351

341-
rewriter.replaceOp(multiReductionOp, result);
352+
rewriter.replaceOp(rootOp, result);
342353
return success();
343354
}
344355
};

mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,23 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
4141
// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
4242
// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
4343
// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
44+
45+
// -----
46+
47+
func.func @vector_masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
48+
%0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
49+
return %0 : vector<2xf32>
50+
}
51+
52+
// ALL-LABEL: func @vector_masked_multi_reduction
53+
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
54+
// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
55+
// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
56+
// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
57+
// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
58+
// INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
59+
// INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
60+
// INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
61+
// INNER-PARALLEL: %[[INNERMASK:.+]] = vector.extract %[[TPMASK]][0] : vector<2xi1> from vector<4x2xi1>
62+
// INNER-PARALLEL: %[[REDUCED:.+]] = arith.mulf %[[INNERVEC]], %[[ACC]] : vector<2xf32>
63+
// INNER-PARALLEL: arith.select %[[INNERMASK]], %[[REDUCED]], %[[ACC]] : vector<2xi1>, vector<2xf32>

0 commit comments

Comments
 (0)