Commit da8a8e9

Mahesh Ravishankar authored and committed
[mlir][Linalg] Move patterns to remove dead arguments and results out of canonicalization.
The patterns to remove dead arguments and results of `linalg.generic` operations are not necessarily canonicalizations. Instead, a new entry point `populateEraseUnusedOperandsAndResults` is added to allow using these patterns when needed. Transformations that rely on these patterns for cleanup now include them explicitly.

Differential Revision: https://reviews.llvm.org/D138085
1 parent 8474a20 commit da8a8e9
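As the message notes, a transformation that previously counted on canonicalization for this cleanup must now request the patterns itself. Below is a minimal sketch of what that looks like, assuming the usual greedy pattern driver and a hypothetical helper name; only populateEraseUnusedOperandsAndResultsPatterns comes from this commit, the rest is illustrative scaffolding.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical cleanup helper: dead arguments/results of linalg.generic ops
// are no longer removed by canonicalization, so the patterns are requested
// explicitly and applied with the greedy driver.
static void eraseUnusedLinalgOperandsAndResults(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}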

File tree

12 files changed (+454, -409 lines changed)


mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 10 additions & 1 deletion

@@ -48,7 +48,12 @@ void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
 
 /// Populate patterns for splitting a `LinalgOp` with multiple statements within
 /// its payload into multiple `GenericOp` that have a single statement.
-void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);
+/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments
+/// and results from the generated decomposed ops. This is default `true` since
+/// the core decomposition patterns relies on these clean up patterns. It is set
+/// to false only for testing purposes.
+void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns,
+                                       bool removeDeadArgsAndResults = true);
 
 /// Populate patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
@@ -76,6 +81,10 @@ void populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns,
     const ControlFusionFn &controlElementwiseOpFusion);
 
+/// Pattern to remove dead operands and results of `linalg.generic` operations.
+/// This is effectively DCE for a linalg op.
+void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
+
 /// Function type to control generic op dimension collapsing. It is expected
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
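For illustration, a hedged sketch of how these two entry points might be combined, for example when testing the decomposition without its built-in cleanup. Here ctx, funcOp, and the greedy driver calls are assumed scaffolding; the two populate functions and the removeDeadArgsAndResults flag are the ones declared in the header above.

// Sketch only: decompose with the dead-argument/result cleanup disabled,
// then run the cleanup as a separate, explicit pattern set.
RewritePatternSet decomposePatterns(ctx);
linalg::populateDecomposeLinalgOpsPattern(decomposePatterns,
                                          /*removeDeadArgsAndResults=*/false);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(decomposePatterns));

RewritePatternSet cleanupPatterns(ctx);
linalg::populateEraseUnusedOperandsAndResultsPatterns(cleanupPatterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(cleanupPatterns));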

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 339 deletions

@@ -871,285 +871,10 @@ void GenericOp::getEffects(
       getDpsInputOperands(), getDpsInitOperands());
 }
 
-static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
-  if (!result.use_empty())
-    return false;
-  // If out operand not used in payload, we can drop it.
-  OpOperand *outputOpOperand =
-      genericOp.getDpsInitOperand(result.getResultNumber());
-  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
-    return true;
-
-  // The out operand that is part of a payload can be dropped if
-  // these conditions are met:
-  // - Result from out operand is dead.
-  // - User of arg is yield.
-  // - outArg data is not being used by other outArgs.
-
-  // Check block arg and cycle from out operand has a single use.
-  BlockArgument outputArg =
-      genericOp.getRegionOutputArgs()[result.getResultNumber()];
-  if (!outputArg.hasOneUse())
-    return false;
-  Operation *argUserOp = *outputArg.user_begin();
-
-  // Check argUser has no other use.
-  if (!argUserOp->use_empty())
-    return false;
-
-  // Check that argUser is a yield.
-  auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
-  if (!yieldOp)
-    return false;
-
-  // Check outArg data is not being used by other outArgs.
-  if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
-    return false;
-
-  return true;
-}
-
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
-struct DeduplicateAndRemoveDeadOperandsAndResults
-    : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Create a map from argument position in the original op to the argument
-    // position in the new op. If the argument is dropped it wont have an entry.
-    SmallVector<OpOperand *> droppedOpOperands;
-
-    // Information needed to build the new op.
-    SmallVector<Value> newInputOperands, newOutputOperands;
-    SmallVector<AffineMap> newIndexingMaps;
-
-    // Gather information about duplicate input operands.
-    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
-        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
-                                 newIndexingMaps);
-
-    // Gather information about the dropped outputs.
-    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
-        deduplicateOutputOperands(genericOp, droppedOpOperands,
-                                  newOutputOperands, newIndexingMaps);
-
-    // Check if there is any change to operands.
-    if (newInputOperands.size() + newOutputOperands.size() ==
-        genericOp->getNumOperands())
-      return failure();
-
-    // Create the new op with the body being empty.
-    Location loc = genericOp.getLoc();
-    SmallVector<Type> newResultTypes;
-    for (Value v : newOutputOperands)
-      if (v.getType().isa<TensorType>())
-        newResultTypes.push_back(v.getType());
-    auto newOp = rewriter.create<GenericOp>(
-        loc, newResultTypes, newInputOperands, newOutputOperands,
-        rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.getIteratorTypes(), genericOp.getDocAttr(),
-        genericOp.getLibraryCallAttr(),
-        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
-          return;
-        });
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs())
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
-        newOp->setAttr(kv.getName(), kv.getValue());
-
-    // Fix up the payload of the canonicalized operation.
-    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
-                      origOutsToNewOutsPos, rewriter);
-
-    // Replace all live uses of the op.
-    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
-    for (const auto &result : llvm::enumerate(genericOp.getResults())) {
-      auto it = origOutsToNewOutsPos.find(result.index());
-      if (it == origOutsToNewOutsPos.end())
-        continue;
-      replacementsVals[result.index()] = newOp.getResult(it->second);
-    }
-    rewriter.replaceOp(genericOp, replacementsVals);
-    return success();
-  }
-
-private:
-  // Deduplicate input operands, and return the
-  // - Mapping from operand position in the original op, to operand position in
-  //   the canonicalized op.
-  // - The preserved input operands list (by reference).
-  llvm::SmallDenseMap<unsigned, unsigned>
-  deduplicateInputOperands(GenericOp genericOp,
-                           SmallVector<OpOperand *> &droppedOpOperands,
-                           SmallVector<Value> &newInputOperands,
-                           SmallVector<AffineMap> &newIndexingMaps) const {
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
-    for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
-      OpOperand *inputOpOperand = en.value();
-      // Check if operand is dead and if dropping the indexing map makes the
-      // loops to shape computation invalid.
-      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
-        // Add the current operands to the list of potentially droppable
-        // operands. If it cannot be dropped, this needs to be popped back.
-        droppedOpOperands.push_back(inputOpOperand);
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
-          continue;
-        droppedOpOperands.pop_back();
-      }
-
-      // Check if this operand is a duplicate.
-      AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
-      auto it = dedupedInputs.find(
-          std::make_pair(inputOpOperand->get(), indexingMap));
-      if (it != dedupedInputs.end()) {
-        origToNewPos[en.index()] = it->second;
-        droppedOpOperands.push_back(inputOpOperand);
-        continue;
-      }
-
-      // This is a preserved argument.
-      origToNewPos[en.index()] = newInputOperands.size();
-      dedupedInputs[{inputOpOperand->get(), indexingMap}] =
-          newInputOperands.size();
-      newInputOperands.push_back(inputOpOperand->get());
-      newIndexingMaps.push_back(indexingMap);
-    }
-    return origToNewPos;
-  }
-
-  // Deduplicate output operands, and return the
-  // - Mapping from operand position in the original op, to operand position in
-  //   the canonicalized op.
-  // - The preserved output operands list (by reference).
-  llvm::SmallDenseMap<unsigned, unsigned>
-  deduplicateOutputOperands(GenericOp genericOp,
-                            SmallVector<OpOperand *> &droppedOpOperands,
-                            SmallVector<Value> &newOutputOperands,
-                            SmallVector<AffineMap> &newIndexingMaps) const {
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
-        dedupedOutpts;
-    // If the op doesnt have tensor semantics, keep all the outputs as
-    // preserved.
-    if (!genericOp.hasTensorSemantics()) {
-      for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
-        origToNewPos[en.index()] = newOutputOperands.size();
-        newOutputOperands.push_back(en.value()->get());
-        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
-      }
-      return origToNewPos;
-    }
-    // Output argument can be dropped if the result has
-    // - no users, and
-    // - it is not used in the payload, and
-    // - the corresponding indexing maps are not needed for loop bound
-    //   computation.
-    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-    for (const auto &outputOpOperand :
-         llvm::enumerate(genericOp.getDpsInitOperands())) {
-      OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
-      AffineMap indexingMap =
-          genericOp.getMatchingIndexingMap(outputOpOperand.value());
-      auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                                 yieldOp->getOperand(outputOpOperand.index()));
-      if (isResultValueDead(genericOp, result)) {
-        // Check if the opoperand can be dropped without affecting loop
-        // bound computation. Add the operand to the list of dropped op
-        // operand for checking. If it cannot be dropped, need to pop the
-        // value back.
-        droppedOpOperands.push_back(outputOpOperand.value());
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-          continue;
-        }
-        droppedOpOperands.pop_back();
-      }
-
-      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-        // The out operand can also be dropped if it is computed redundantly
-        // by another result, the conditions for that are
-        // - The same operand is used as the out operand
-        // - The same indexing map is used
-        // - The same yield value is used.
-        auto it = dedupedOutpts.find(key);
-        if (it != dedupedOutpts.end()) {
-          origToNewPos[outputOpOperand.index()] = it->second;
-          droppedOpOperands.push_back(outputOpOperand.value());
-          continue;
-        }
-      }
-
-      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-      dedupedOutpts[key] = newOutputOperands.size();
-      newOutputOperands.push_back(outputOpOperand.value()->get());
-      newIndexingMaps.push_back(
-          genericOp.getMatchingIndexingMap(outputOpOperand.value()));
-    }
-    return origToNewPos;
-  }
-
-  // Populate the body of the canonicalized operation.
-  void populateOpPayload(
-      GenericOp genericOp, GenericOp newOp,
-      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
-      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
-      PatternRewriter &rewriter) const {
-    // Merge the body of the original op with the new op.
-    Block *newOpBlock = &newOp.getRegion().front();
-    assert(newOpBlock->empty() && "expected new op to have an empty payload");
-    Block *origOpBlock = &genericOp.getRegion().front();
-    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-
-    // Replace all arguments in the original op, with arguments from the
-    // canonicalized op.
-    auto updateReplacements =
-        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
-            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
-          for (const auto &origOperand : llvm::enumerate(origOperands)) {
-            auto it = map.find(origOperand.index());
-            if (it == map.end())
-              continue;
-            OpOperand *newOperand = newOperands[it->second];
-            replacements[origOperand.value()->getOperandNumber()] =
-                newOpBlock->getArgument(newOperand->getOperandNumber());
-          }
-        };
-
-    OpOperandVector origInputOperands = genericOp.getDpsInputOperands();
-    OpOperandVector newInputOperands = newOp.getDpsInputOperands();
-    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
-
-    OpOperandVector origOutputOperands = genericOp.getDpsInitOperands();
-    OpOperandVector newOutputOperands = newOp.getDpsInitOperands();
-    updateReplacements(origOutputOperands, newOutputOperands,
-                       origOutsToNewOutsPos);
-
-    // Drop the unused yield args.
-    if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
-      OpBuilder::InsertionGuard g(rewriter);
-      YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
-      rewriter.setInsertionPoint(origYieldOp);
-
-      SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
-      for (const auto &yieldOpOperands :
-           llvm::enumerate(origYieldOp.getValues())) {
-        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
-        if (it == origOutsToNewOutsPos.end())
-          continue;
-        newYieldVals[it->second] = yieldOpOperands.value();
-      }
-      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
-    }
-
-    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
-  }
-};
-
 /// Remove generic operations (on tensors) that are just copying
 /// the values from inputs to the results. Requirements are
 /// 1) All iterator types are parallel
@@ -1227,74 +952,11 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
   }
 };
 
-/// Remove unused cycles.
-/// We can remove unused cycle within a payload of generic region
-/// if these conditions are met:
-/// - Result from out operand is dead.
-/// - Block arg from out operand has a single use in the %cycle
-/// instruction.
-/// - Cycle has a single use and it is in yield.
-struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-
-    // If the op doesnt have tensor semantics, preserve the outputs as is.
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
-    bool hasRemovedCycles = false;
-    // Iterate over output operands and remove any unused cycles.
-    for (const auto &outputOpOperand :
-         llvm::enumerate(genericOp.getDpsInitOperands())) {
-
-      // Check that result from out operand is dead.
-      Value result = genericOp.getResult(outputOpOperand.index());
-      if (!result.use_empty())
-        continue;
-
-      // Check that outputArg has one use in cycle.
-      BlockArgument outputArg =
-          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
-      if (!outputArg.hasOneUse())
-        continue;
-
-      // Check cycle has at most one use.
-      Operation *cycleOp = *outputArg.user_begin();
-      if (!cycleOp->hasOneUse())
-        continue;
-
-      // Check that the cycleUser is a yield.
-      Operation *cycleUserOp = *cycleOp->user_begin();
-      if (!isa<linalg::YieldOp>(cycleUserOp))
-        continue;
-
-      // Check that argIndex matches yieldIndex, else data is being used.
-      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
-          cycleOp->getResult(0))
-        continue;
-
-      // Directly replace the cycle with the blockArg such that
-      // Deduplicate pattern can eliminate it along with unused yield.
-      rewriter.replaceOp(cycleOp, outputArg);
-      rewriter.updateRootInPlace(genericOp, [] {});
-      hasRemovedCycles = true;
-    }
-
-    if (hasRemovedCycles) {
-      return success();
-    }
-
-    return failure();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateAndRemoveDeadOperandsAndResults,
-              EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
+  results.add<EraseIdentityGenericOp>(context);
}
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
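The patterns removed above are not deleted outright; per the commit message they move behind the new populate entry point. The new implementation file is not shown in this view, so the following is only a plausible sketch of the registration, assuming the pattern classes keep the names seen in the removed code:

// Hypothetical contents of the new entry point (not the verbatim upstream
// file): the patterns dropped from GenericOp canonicalization are exposed
// through populateEraseUnusedOperandsAndResultsPatterns instead.
void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<DeduplicateAndRemoveDeadOperandsAndResults,
               RemoveUnusedCycleInGenericOp>(patterns.getContext());
}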

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   DropUnitDims.cpp
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
+  EraseUnusedOperandsAndResults.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   FusionOnTensors.cpp
