Commit b1d8205
[mlir][Vector] Add lowering support for 1-D masked multi-reductions
1-D multi-reductions follow a different lowering path (they are converted to 2-D multi-reductions), so their masked variants need to be supported explicitly.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D143453
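
For context, the masked 1-D case this pattern now handles looks roughly like the sketch below (modeled on the new test case in this commit; the value names and the vector<8xf32> shape are illustrative, not compiler output):

  %mask = vector.create_mask %dim : vector<8xi1>
  %res = vector.mask %mask { vector.multi_reduction <add>, %in, %acc [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32

Previously OneDimMultiReductionToTwoDim bailed out on such masked ops (the removed "TODO: Support masking" path); with this change the vector.mask wrapper is carried through the 1-D-to-2-D conversion.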
Parent: 08749a9

2 files changed: +49, -10 lines

mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp

Lines changed: 29 additions & 10 deletions
@@ -385,17 +385,25 @@ struct OneDimMultiReductionToTwoDim
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: Support masking.
-      return failure();
-
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-1 or bail.
     if (srcRank != 1)
       return failure();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto loc = multiReductionOp.getLoc();
     auto srcVectorType = multiReductionOp.getSourceVectorType();
     auto srcShape = srcVectorType.getShape();
@@ -408,16 +416,27 @@ struct OneDimMultiReductionToTwoDim
 
     // If the unique dim is reduced and we insert a parallel in front, we need a
     // {false, true} mask.
-    SmallVector<bool, 2> mask{false, true};
+    SmallVector<bool, 2> reductionMask{false, true};
 
     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
     Value castAcc = rewriter.create<vector::BroadcastOp>(
         loc, accType, multiReductionOp.getAcc());
-    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
-        loc, cast, castAcc, mask, multiReductionOp.getKind());
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
+    Value castMask;
+    if (maskableOp.isMasked()) {
+      auto maskType = mask.getType().cast<ShapedType>();
+      auto castMaskType =
+          VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
+                          maskType.getElementType());
+      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
+    }
+
+    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
+    newOp = vector::maskOperation(rewriter, newOp, castMask);
+
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                    ArrayRef<int64_t>{0});
     return success();
   }
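
Conceptually, the rewritten pattern above turns a masked 1-D multi-reduction into a masked 2-D multi-reduction followed by an extract, roughly as in this hand-written sketch (it follows the pattern's inline comment, not actual pass output; names and shapes are illustrative):

  // Source, accumulator and mask all gain a leading unit dimension.
  %cast = vector.shape_cast %in : vector<8xf32> to vector<1x8xf32>
  %cast_acc = vector.broadcast %acc : f32 to vector<1xf32>
  %cast_mask = vector.broadcast %mask : vector<8xi1> to vector<1x8xi1>
  // The 2-D multi-reduction stays under vector.mask; only dim 1 is reduced.
  %red = vector.mask %cast_mask { vector.multi_reduction <add>, %cast, %cast_acc [1] : vector<1x8xf32> to vector<1xf32> } : vector<1x8xi1> -> vector<1xf32>
  // The scalar result is extracted back out, replacing the original masked op.
  %res = vector.extract %red[0] : vector<1xf32>

Subsequent multi-reduction lowering patterns then collapse this 2-D form into the masked vector.reduction that the new test below checks for.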

mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

Lines changed: 20 additions & 0 deletions
@@ -189,6 +189,26 @@ func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf
 
 // -----
 
+func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %c0_1 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim : vector<8xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1], %cst {in_bounds = [true]} : tensor<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %cst [0] : vector<8xf32> to f32 } : vector<8xi1> -> f32
+  return %4 : f32
+}
+
+// Verify that a 1-D vector.multi_reduction is transformed into a vector.reduction.
+// This transform expands 1-D vectors into 2-D.
+
+// CHECK-LABEL: func.func @vectorize_1d_dynamic_reduction(
+// CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
+// CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+
+// -----
+
 func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>