Skip to content

Commit afc3756

Browse files
committed
[mlir][vector] Masking support for reductions in Linalg vectorizer
This patch enables vectorization of reductions in Linalg vectorizer using the vector.mask operation. It also introduces the logic to slice and propagate the vector mask of a masked multi-reduction to their respective lowering operations. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D141571
1 parent 60dd937 commit afc3756

File tree

7 files changed

+351
-66
lines changed

7 files changed

+351
-66
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,20 @@ inline bool isReductionIterator(Attribute attr) {
203203
return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
204204
}
205205

206+
//===----------------------------------------------------------------------===//
207+
// Vector Masking Utilities
208+
//===----------------------------------------------------------------------===//
209+
210+
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
211+
/// as masked operation.
212+
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
213+
214+
/// Creates a vector.mask operation around a maskable operation. Returns the
215+
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
216+
/// maskable operation itself.
217+
Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
218+
Value mask);
219+
206220
} // namespace vector
207221
} // namespace mlir
208222

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def Vector_MultiDimReductionOp :
340340
PredOpTrait<"source operand and result have same element type",
341341
TCresVTEtIsSameAsOpBase<0, 0>>,
342342
DeclareOpInterfaceMethods<InferTypeOpInterface>,
343+
DeclareOpInterfaceMethods<MaskableOpInterface>,
343344
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
344345
["getShapeForUnroll"]>]>,
345346
Arguments<(ins Vector_CombiningKindAttr:$kind,
@@ -2338,16 +2339,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
23382339

23392340
let skipDefaultBuilders = 1;
23402341
let builders = [
2341-
OpBuilder<(ins "Value":$mask,
2342-
CArg<"function_ref<void(OpBuilder &, Location)>",
2343-
"buildTerminatedBody">:$maskRegion)>,
2344-
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
2345-
CArg<"function_ref<void(OpBuilder &, Location)>",
2346-
"buildTerminatedBody">:$maskRegion)>,
2347-
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
2348-
"Value":$passthru,
2349-
CArg<"function_ref<void(OpBuilder &, Location)>",
2350-
"buildTerminatedBody">:$maskRegion)>
2342+
OpBuilder<(ins "Value":$mask, "Operation *":$maskableOp,
2343+
CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
2344+
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Operation *":$maskableOp,
2345+
CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
2346+
OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru,
2347+
"Operation *":$maskableOp,
2348+
CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>
23512349
];
23522350

23532351
let extraClassDeclaration = [{

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -292,25 +292,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
292292

293293
// Wrap the operation with a new `vector.mask` and update D-U chain.
294294
assert(opToMask && "Expected a valid operation to mask");
295-
auto opResults = opToMask->getResultTypes();
296-
auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
297-
Block *insBlock = builder.getInsertionBlock();
298-
// Create a block, put an op in that block. Look for a utility.
299-
// Maybe in conversion pattern rewriter. Way to avoid splice.
300-
// Set insertion point.
301-
insBlock->getOperations().splice(
302-
insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
303-
builder.create<vector::YieldOp>(loc, opToMask->getResults());
304-
};
305-
// TODO: Allow multiple results in vector.mask.
306-
auto maskOp =
307-
opResults.empty()
308-
? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
309-
createRegionMask)
310-
: rewriter.create<vector::MaskOp>(opToMask->getLoc(),
311-
opToMask->getResultTypes().front(),
312-
mask, createRegionMask);
313-
295+
auto maskOp = cast<vector::MaskOp>(
296+
mlir::vector::maskOperation(rewriter, opToMask, mask));
314297
Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
315298

316299
for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
@@ -440,17 +423,16 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
440423
/// initial value.buildMultiDimReduce
441424
// Note: this is a true builder that notifies the OpBuilder listener.
442425
// TODO: Consider moving as a static helper on the ReduceOp.
443-
static Operation *buildMultiDimReduce(OpBuilder &b,
444-
Operation *reduceOp, Value valueToReduce,
445-
Value acc,
446-
const SmallVector<bool> &reductionMask) {
426+
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
427+
Value valueToReduce, Value acc,
428+
ArrayRef<bool> dimsToMask) {
447429
auto maybeKind = getCombinerOpKind(reduceOp);
448430
assert(maybeKind && "Failed precondition: could not get reduction kind");
449431
return b.create<vector::MultiDimReductionOp>(
450-
reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
432+
reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
451433
}
452434

453-
static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
435+
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
454436
return llvm::to_vector(
455437
llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
456438
}
@@ -701,8 +683,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
701683
if (!reduceType ||
702684
(outputType && reduceType.getShape() == outputType.getShape()))
703685
return nullptr;
704-
SmallVector<bool> reductionMask = getReductionMask(linalgOp);
705-
return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
686+
SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
687+
return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
706688
}
707689

708690
/// Generic vectorization for a single operation `op`, given already vectorized
@@ -972,11 +954,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
972954
}
973955

974956
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
975-
// TODO: Masking only supports dynamic generic ops without reductions for now.
976-
if (!isElementwise(op) &&
977-
llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
978-
return itType != utils::IteratorType::parallel;
979-
}))
957+
// TODO: Masking only supports dynamic generic ops for now.
958+
if (!isa<linalg::GenericOp>(op))
980959
return failure();
981960

982961
// TODO: 0-d vectors are not supported yet.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,13 @@ LogicalResult MultiDimReductionOp::verify() {
342342
return success();
343343
}
344344

345+
/// Returns the mask type expected by this operation.
346+
Type MultiDimReductionOp::getExpectedMaskType() {
347+
auto vecType = getSourceVectorType();
348+
return VectorType::get(vecType.getShape(),
349+
IntegerType::get(vecType.getContext(), /*width=*/1));
350+
}
351+
345352
namespace {
346353
// Only unit dimensions that are being reduced are folded. If the dimension is
347354
// unit, but not reduced, it is not folded, thereby keeping the output type the
@@ -5276,29 +5283,31 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
52765283

52775284
void MaskOp::build(
52785285
OpBuilder &builder, OperationState &result, Value mask,
5279-
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5286+
Operation *maskableOp,
5287+
function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
52805288
assert(maskRegionBuilder &&
52815289
"builder callback for 'maskRegion' must be present");
52825290

52835291
result.addOperands(mask);
52845292
OpBuilder::InsertionGuard guard(builder);
52855293
Region *maskRegion = result.addRegion();
52865294
builder.createBlock(maskRegion);
5287-
maskRegionBuilder(builder, result.location);
5295+
maskRegionBuilder(builder, maskableOp);
52885296
}
52895297

52905298
void MaskOp::build(
52915299
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5292-
Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5293-
build(builder, result, resultTypes, mask, /*passthru=*/Value(),
5300+
Value mask, Operation *maskableOp,
5301+
function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5302+
build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
52945303
maskRegionBuilder);
52955304
}
52965305

52975306
void MaskOp::build(
5298-
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
5299-
Value mask, Value passthru,
5300-
function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5301-
build(builder, result, mask, maskRegionBuilder);
5307+
OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value mask,
5308+
Value passthru, Operation *maskableOp,
5309+
function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
5310+
build(builder, result, mask, maskableOp, maskRegionBuilder);
53025311
if (passthru)
53035312
result.addOperands(passthru);
53045313
result.addTypes(resultTypes);
@@ -5738,6 +5747,34 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
57385747
llvm_unreachable("unknown CombiningKind");
57395748
}
57405749

5750+
//===----------------------------------------------------------------------===//
5751+
// Vector Masking Utilities
5752+
//===----------------------------------------------------------------------===//
5753+
5754+
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
5755+
/// as masked operation.
5756+
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
5757+
Operation *maskableOp) {
5758+
assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
5759+
Block *insBlock = builder.getInsertionBlock();
5760+
// Create a block and move the op to that block.
5761+
insBlock->getOperations().splice(
5762+
insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
5763+
builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
5764+
}
5765+
5766+
/// Creates a vector.mask operation around a maskable operation. Returns the
5767+
/// vector.mask operation if the mask provided is valid. Otherwise, returns
5768+
/// the maskable operation itself.
5769+
Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
5770+
Operation *maskableOp, Value mask) {
5771+
if (!mask)
5772+
return maskableOp;
5773+
return rewriter.create<MaskOp>(maskableOp->getLoc(),
5774+
maskableOp->getResultTypes(), mask, maskableOp,
5775+
createMaskOpRegion);
5776+
}
5777+
57415778
//===----------------------------------------------------------------------===//
57425779
// TableGen'd op method definitions
57435780
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)