Commit 431213c
[mlir][linalg] Implement patterns for reducing rank of named linalg contraction ops (#95710)
This patch introduces pattern rewrites that reduce the rank of named linalg contraction ops with unit batch or spatial dimension(s), converting them to other named contraction ops. For example, `linalg.batch_matmul` with batch size 1 -> `linalg.matmul`, and `linalg.matmul` with a unit LHS spatial dim -> `linalg.vecmat`, etc. These patterns don't support reducing the rank along the reduction dimension, since those cases don't convert to other named contraction ops.
1 parent 43d207a · commit 431213c
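As a rough usage sketch (not part of this commit), a pass or test driver could collect the new patterns and hand them to the greedy rewriter. Only `populateContractionOpRankReducingPatterns` and `applyPatternsAndFoldGreedily` below come from the patch and existing MLIR APIs; the surrounding driver is illustrative:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative driver: collect the rank-reducing contraction patterns and let
// the greedy rewriter apply them to a fixpoint, so ops with several unit dims
// are reduced one dimension per rewrite.
static LogicalResult reduceContractionRanks(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateContractionOpRankReducingPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Because each pattern reduces a single iteration-space dimension, repeated greedy application is what takes, for example, a `linalg.batch_matmul` with a unit batch dim and a unit M dim down to a `linalg.vecmat` in two steps.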

File tree: 6 files changed (+605, -0 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
@@ -1713,6 +1713,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                     const ControlBlockPackMatmulFn &controlFn);

/// Adds patterns that reduce the rank of named contraction ops that have
/// unit dimensions in the operand(s) by converting to a sequence of
/// `collapse_shape`, `<corresponding linalg named op>`, `expand_shape`
/// (if on tensors). For example a `linalg.batch_matmul` with unit batch size
/// will convert to `linalg.matmul` and a `linalg.matvec` with unit spatial
/// dim in lhs will convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

} // namespace linalg
} // namespace mlir

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 261 additions & 0 deletions
@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
  }
};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}
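For concreteness, a few hand-worked outputs of the helper above, written as a hypothetical snippet (not part of the patch) that assumes it sits next to the helper in this file:

// Hypothetical illustration: expected groupings for a few (rank, pos) pairs,
// worked out by hand from the loop above.
static void illustrateReassociationForReshapeAtDim() {
  // rank 2: the only two dims always merge, whatever `pos` is.
  SmallVector<ReassociationIndices> rank2 =
      getReassociationForReshapeAtDim(/*rank=*/2, /*pos=*/0); // {{0, 1}}
  // rank 3, unit dim at the front: it merges into its right neighbour.
  SmallVector<ReassociationIndices> rank3Front =
      getReassociationForReshapeAtDim(/*rank=*/3, /*pos=*/0); // {{0, 1}, {2}}
  // rank 3, unit dim at the back: it merges into the previous group.
  SmallVector<ReassociationIndices> rank3Back =
      getReassociationForReshapeAtDim(/*rank=*/3, /*pos=*/2); // {{0}, {1, 2}}
  // rank 4, unit dim in the middle.
  SmallVector<ReassociationIndices> rank4Mid =
      getReassociationForReshapeAtDim(/*rank=*/4, /*pos=*/1); // {{0}, {1, 2}, {3}}
  (void)rank2; (void)rank3Front; (void)rank3Back; (void)rank4Mid;
}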
/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}
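As a small illustrative use of the helper above (hypothetical, not in the patch): collapsing the leading unit dim of a `tensor<1x4x8xf32>` value produces a `tensor<4x8xf32>` via `tensor.collapse_shape` with reassociation `{{0, 1}, {2}}`, while a negative `pos` leaves the value untouched:

// Hypothetical wrapper: collapse the leading unit dim of `val` on request,
// otherwise pass the value through unchanged (pos == -1).
static Value collapseLeadingUnitDimIfRequested(PatternRewriter &rewriter,
                                               Value val, bool collapse) {
  return collapseSingletonDimAt(rewriter, val, /*pos=*/collapse ? 0 : -1);
}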
/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducable dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
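To make the `getOperandUnitDims` contract concrete, here is a minimal hypothetical subclass (not part of the patch) that unconditionally reports dim 0 of the LHS and init and skips the RHS; the real subclasses below derive these indices from inferred contraction dimensions and verify that the dims are actually unit-sized:

// Hypothetical subclass sketch: report {lhs dim, rhs dim, init dim}, where a
// negative index means "do not collapse this operand".
template <typename FromOpTy, typename ToOpTy>
struct CollapseLhsUnitDimExample : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    // Collapse dim 0 of the LHS and dim 0 of the init; leave the RHS alone.
    operandUnitDims.assign({0, -1, 0});
    return success();
  }
};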
/// Patterns for unbatching batched contraction ops
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};
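As a worked example (hypothetical sketch, mirroring one line of the populate function further down): for a `linalg.batch_matmul` with operands `B x M x K`, `B x K x N`, and init `B x M x N` where `B == 1`, the batch dimension maps to dim 0 of all three operands, so this pattern reports `{0, 0, 0}` and the op collapses to a `linalg.matmul`:

// Hypothetical helper: register just the unit-batch batch_matmul -> matmul
// rewrite described above.
static void addUnbatchBatchMatmulPattern(RewritePatternSet &patterns) {
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(
      patterns.getContext());
}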
/// Patterns for reducing non-batch dimensions
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};
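The `reduceLeft` flag above decides which side the unit spatial dim is collapsed from. A hypothetical compile-time check (not in the patch) makes the split concrete: `matmul -> vecmat` collapses a unit M dim from the LHS and init (`operandUnitDims = {0, -1, 0}`), while `matmul -> matvec` collapses a unit N dim from the RHS and init (`operandUnitDims = {-1, 1, 1}`):

// Hypothetical sanity checks on the reduceLeft split.
static_assert(RankReduceMatmul<MatmulOp, VecmatOp>::reduceLeft,
              "1xK matmul KxN reduces the LHS/init side to a vecmat");
static_assert(!RankReduceMatmul<MatmulOp, MatvecOp>::reduceLeft,
              "MxK matmul Kx1 reduces the RHS/init side to a matvec");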
} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
          context);
  patterns
      .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
          context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank 1 reducing patterns
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
  // Batch rank 1 reducing patterns
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
      context);
  patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
      context);

  // Non-batch rank 0 reducing patterns
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
