Skip to content

Commit 2c50fc3

Browse files
manupakjoaosaffran
authored andcommitted
[MLIR][Vector] Add support for inner-parallel masked multi-reductions (llvm#126722)
This commit adds support to lower inner-parallel flavor of masked vector multi-reductions.
1 parent 6ce1450 commit 2c50fc3

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,6 @@ struct TwoDimMultiReductionToElementWise
308308

309309
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
310310
PatternRewriter &rewriter) const override {
311-
auto maskableOp =
312-
cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
313-
if (maskableOp.isMasked())
314-
// TODO: Support masking.
315-
return failure();
316-
317311
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
318312
// Rank-2 ["parallel", "reduce"] or bail.
319313
if (srcRank != 2)
@@ -330,15 +324,33 @@ struct TwoDimMultiReductionToElementWise
330324
if (!elementType.isIntOrIndexOrFloat())
331325
return failure();
332326

327+
OpBuilder::InsertionGuard guard(rewriter);
328+
auto maskableOp =
329+
cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
330+
Operation *rootOp;
331+
Value mask = nullptr;
332+
if (maskableOp.isMasked()) {
333+
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
334+
rootOp = maskableOp.getMaskingOp();
335+
mask = maskableOp.getMaskingOp().getMask();
336+
} else {
337+
rootOp = multiReductionOp;
338+
}
339+
333340
Value result = multiReductionOp.getAcc();
334341
for (int64_t i = 0; i < srcShape[0]; i++) {
335342
auto operand = rewriter.create<vector::ExtractOp>(
336343
loc, multiReductionOp.getSource(), i);
337-
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
338-
operand, result);
344+
Value extractMask = nullptr;
345+
if (mask) {
346+
extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
347+
}
348+
result =
349+
makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
350+
result, /*fastmath=*/nullptr, extractMask);
339351
}
340352

341-
rewriter.replaceOp(multiReductionOp, result);
353+
rewriter.replaceOp(rootOp, result);
342354
return success();
343355
}
344356
};

mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,23 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
4141
// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
4242
// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
4343
// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
44+
45+
// -----
46+
47+
func.func @vector_multi_reduction_masked(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
48+
%0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
49+
return %0 : vector<2xf32>
50+
}
51+
52+
// ALL-LABEL: func @vector_multi_reduction_masked
53+
// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
54+
// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
55+
// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
56+
// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
57+
// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
58+
// INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
59+
// INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
60+
// INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
61+
// INNER-PARALLEL: %[[INNERMASK:.+]] = vector.extract %[[TPMASK]][0] : vector<2xi1> from vector<4x2xi1>
62+
// INNER-PARALLEL: %[[REDUCED:.+]] = arith.mulf %[[INNERVEC]], %[[ACC]] : vector<2xf32>
63+
// INNER-PARALLEL: arith.select %[[INNERMASK]], %[[REDUCED]], %[[ACC]] : vector<2xi1>, vector<2xf32>

0 commit comments

Comments
 (0)