[mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops #95710

Merged · 14 commits · Jun 24, 2024

7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);

/// Adds patterns that reduce the rank of named contraction ops that have
/// unit dimensions in the operand(s) by converting them to a sequence of
/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape`
/// (if on tensors). For example, a `linalg.batch_matmul` with unit batch size
/// will convert to `linalg.matmul`, and a `linalg.matvec` with a unit spatial
/// dim in the lhs will convert to a `linalg.dot`.
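///
/// As an illustrative sketch (shapes chosen here for exposition, not taken
/// from this patch), the tensor form of the `linalg.batch_matmul` case
/// becomes roughly:
///
///   %lhs2d  = tensor.collapse_shape %lhs [[0, 1], [2]]
///       : tensor<1x4x8xf32> into tensor<4x8xf32>
///   %rhs2d  = tensor.collapse_shape %rhs [[0, 1], [2]]
///       : tensor<1x8x16xf32> into tensor<8x16xf32>
///   %init2d = tensor.collapse_shape %init [[0, 1], [2]]
///       : tensor<1x4x16xf32> into tensor<4x16xf32>
///   %mm = linalg.matmul
///       ins(%lhs2d, %rhs2d : tensor<4x8xf32>, tensor<8x16xf32>)
///       outs(%init2d : tensor<4x16xf32>) -> tensor<4x16xf32>
///   %res = tensor.expand_shape %mm [[0, 1], [2]] output_shape [1, 4, 16]
///       : tensor<4x16xf32> into tensor<1x4x16xf32>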
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

261 changes: 261 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
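/// For example (values for illustration only, not asserted anywhere in this
/// patch), `rank = 4, pos = 1` yields `[[0], [1, 2], [3]]`,
/// `rank = 3, pos = 2` yields `[[0], [1, 2]]`, and `rank = 2` always yields
/// `[[0, 1]]`.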
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
bool lastDim = pos == rank - 1;
if (rank > 2) {
for (int64_t i = 0; i < rank - 1; i++) {
if (i == pos || (lastDim && i == pos - 1))
reassociation[i] = ReassociationIndices{i, i + 1};
else if (i < pos)
reassociation[i] = ReassociationIndices{i};
else
reassociation[i] = ReassociationIndices{i + 1};
}
}
return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
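/// For example (hypothetical shapes), collapsing a `tensor<1x4x8xf32>` at
/// `pos = 0` yields a `tensor<4x8xf32>` via `tensor.collapse_shape`; memref
/// operands are handled analogously.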
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
int64_t pos) {
if (pos < 0)
return val;
auto valType = cast<ShapedType>(val.getType());
SmallVector<int64_t> collapsedShape(valType.getShape());
collapsedShape.erase(collapsedShape.begin() + pos);
return collapseValue(
rewriter, val.getLoc(), val, collapsedShape,
getReassociationForReshapeAtDim(valType.getRank(), pos),
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims happens through recursive application of the
/// pattern rewrites.
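/// For example, a `linalg.batch_matmul` whose batch and M dimensions are both
/// unit-sized is rewritten in two steps: one application drops the batch
/// dimension (producing `linalg.matmul`), and a second application drops the
/// unit M dimension (producing `linalg.vecmat`).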
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
using OpRewritePattern<FromOpTy>::OpRewritePattern;

/// Collapse all collapsible operands.
SmallVector<Value>
collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
ArrayRef<int64_t> operandCollapseDims) const {
assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
"expected 3 operands and dims");
return llvm::map_to_vector(
llvm::zip(operands, operandCollapseDims), [&](auto pair) {
return collapseSingletonDimAt(rewriter, std::get<0>(pair),
std::get<1>(pair));
});
}

/// Expand result tensor.
Value expandResult(PatternRewriter &rewriter, Value result,
RankedTensorType expandedType, int64_t dim) const {
return rewriter.create<tensor::ExpandShapeOp>(
result.getLoc(), expandedType, result,
getReassociationForReshapeAtDim(expandedType.getRank(), dim));
}

LogicalResult matchAndRewrite(FromOpTy contractionOp,
PatternRewriter &rewriter) const override {

auto loc = contractionOp.getLoc();
auto inputs = contractionOp.getDpsInputs();
auto inits = contractionOp.getDpsInits();
if (inputs.size() != 2 || inits.size() != 1)
return rewriter.notifyMatchFailure(contractionOp,
"expected 2 inputs and 1 init");
auto lhs = inputs[0];
auto rhs = inputs[1];
auto init = inits[0];
SmallVector<Value> operands{lhs, rhs, init};

SmallVector<int64_t> operandUnitDims;
if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
return rewriter.notifyMatchFailure(contractionOp,
"no reducable dims found");

SmallVector<Value> collapsedOperands =
collapseOperands(rewriter, operands, operandUnitDims);
Value collapsedLhs = collapsedOperands[0];
Value collapsedRhs = collapsedOperands[1];
Value collapsedInit = collapsedOperands[2];
SmallVector<Type, 1> collapsedResultTy;
if (isa<RankedTensorType>(collapsedInit.getType()))
collapsedResultTy.push_back(collapsedInit.getType());
auto collapsedOp = rewriter.create<ToOpTy>(
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
ValueRange{collapsedInit});
for (auto attr : contractionOp->getAttrs()) {
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
continue;
collapsedOp->setAttr(attr.getName(), attr.getValue());
}

auto results = contractionOp.getResults();
assert(results.size() < 2 && "expected at most one result");
if (results.empty()) {
rewriter.replaceOp(contractionOp, collapsedOp);
} else {
rewriter.replaceOp(
contractionOp,
expandResult(rewriter, collapsedOp.getResultTensors()[0],
cast<RankedTensorType>(results[0].getType()),
operandUnitDims[2]));
}

return success();
}

/// Populate `operandUnitDims` with 3 indices indicating the unit dim
/// for each operand that should be collapsed in this pattern. If an
/// operand shouldn't be collapsed, the index should be negative.
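/// For example, for a `linalg.batch_matmul` with a unit batch size this would
/// be `{0, 0, 0}`, since the batch dimension is dimension 0 of the lhs, rhs,
/// and init operands; for a `linalg.matmul` with a unit M dimension it would
/// be `{0, -1, 0}`, since the rhs carries no M dimension and is left as is.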
virtual LogicalResult
getOperandUnitDims(LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops
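/// For example (hypothetical static shapes), a `linalg.batch_matvec` with
/// operands of type `tensor<1x5x3xf32>`, `tensor<1x3xf32>` and init
/// `tensor<1x5xf32>` is rewritten to a `linalg.matvec` on `tensor<5x3xf32>`,
/// `tensor<3xf32>` and `tensor<5xf32>`.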
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

/// Look for unit batch dims to collapse.
LogicalResult
getOperandUnitDims(LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const override {
FailureOr<ContractionDimensions> maybeContractionDims =
inferContractionDims(op);
if (failed(maybeContractionDims)) {
LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
return failure();
}
ContractionDimensions contractionDims = maybeContractionDims.value();

if (contractionDims.batch.size() != 1)
return failure();
auto batchDim = contractionDims.batch[0];
SmallVector<std::pair<Value, unsigned>, 3> bOperands;
op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] != 1;
})) {
LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
return failure();
}

operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
std::get<1>(bOperands[1]),
std::get<1>(bOperands[2])};
return success();
}
};

/// Patterns for reducing non-batch dimensions
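/// For example, a `linalg.matmul` whose M dimension is unit-sized maps to
/// `linalg.vecmat`, one whose N dimension is unit-sized maps to
/// `linalg.matvec`, and a `linalg.matvec` whose M dimension is unit-sized
/// maps to `linalg.dot`.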
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

/// Helper for determining whether the unit dim is dropped from the lhs/init
/// operands (`reduceLeft`) or from the rhs/init operands.
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

/// Look for non-batch spatial dims to collapse.
LogicalResult
getOperandUnitDims(LinalgOp op,
SmallVectorImpl<int64_t> &operandUnitDims) const override {
FailureOr<ContractionDimensions> maybeContractionDims =
inferContractionDims(op);
if (failed(maybeContractionDims)) {
LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
return failure();
}
ContractionDimensions contractionDims = maybeContractionDims.value();

if constexpr (reduceLeft) {
auto m = contractionDims.m[0];
SmallVector<std::pair<Value, unsigned>, 2> mOperands;
op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
if (mOperands.size() != 2)
return failure();
if (llvm::all_of(mOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] == 1;
})) {
operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
std::get<1>(mOperands[1])};
return success();
}
} else {
auto n = contractionDims.n[0];
SmallVector<std::pair<Value, unsigned>, 2> nOperands;
op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
if (nOperands.size() != 2)
return failure();
if (llvm::all_of(nOperands, [](auto pair) {
return cast<ShapedType>(std::get<0>(pair).getType())
.getShape()[std::get<1>(pair)] == 1;
})) {
operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
std::get<1>(nOperands[1])};
return success();
}
}
LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
return failure();
}
};

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
patterns
.add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
context);
patterns
.add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
context);
patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
context);

// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
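
A minimal sketch of how these patterns might be driven from a pass, mirroring the greedy rewrite call used by LinalgFoldUnitExtentDimsPass earlier in this file; the pass name and boilerplate below are hypothetical and not part of this patch:

struct RankReduceContractionOpsPass
    : PassWrapper<RankReduceContractionOpsPass, OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    linalg::populateContractionOpRankReducingPatterns(patterns);
    // Run to a fixed point so ops with several unit dims are reduced repeatedly.
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};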