Skip to content

Commit ac0fe5d

Browse files
committed
[mlir][linalg] Remove unused payload related OutOpOperand
Some higher level operations such as torch.max generates linalg generic that returns both the index and the value of the max operation. However sometimes not all information is being used. This however blocks vectorization for certain cases which causes performance degradation. This patch aims to fix this issue. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D135388
1 parent 9d69c60 commit ac0fe5d

File tree

2 files changed

+310
-53
lines changed

2 files changed

+310
-53
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 151 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,44 @@ void GenericOp::getEffects(
857857
outputBuffers);
858858
}
859859

860+
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
861+
if (!result.use_empty())
862+
return false;
863+
// If out operand not used in payload, we can drop it.
864+
OpOperand *outputOpOperand =
865+
genericOp.getOutputOperand(result.getResultNumber());
866+
if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
867+
return true;
868+
869+
// The out operand that is part of a payload can be dropped if
870+
// these conditions are met:
871+
// - Result from out operand is dead.
872+
// - User of arg is yield.
873+
// - outArg data is not being used by other outArgs.
874+
875+
// Check block arg and cycle from out operand has a single use.
876+
BlockArgument outputArg =
877+
genericOp.getRegionOutputArgs()[result.getResultNumber()];
878+
if (!outputArg.hasOneUse())
879+
return false;
880+
Operation *argUserOp = *outputArg.user_begin();
881+
882+
// Check argUser has no other use.
883+
if (!argUserOp->use_empty())
884+
return false;
885+
886+
// Check that argUser is a yield.
887+
auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
888+
if (!yieldOp)
889+
return false;
890+
891+
// Check outArg data is not being used by other outArgs.
892+
if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
893+
return false;
894+
895+
return true;
896+
}
897+
860898
LogicalResult GenericOp::verify() { return success(); }
861899

862900
namespace {
@@ -995,57 +1033,55 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
9951033
newIndexingMaps.push_back(
9961034
genericOp.getMatchingIndexingMap(outputOpOperand.value()));
9971035
}
998-
} else {
999-
// Output argument can be dropped if the result has
1000-
// - no users, and
1001-
// - it is not used in the payload, and
1002-
// - the corresponding indexing maps are not needed for loop bound
1003-
// computation.
1004-
auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
1005-
for (const auto &outputOpOperand :
1006-
llvm::enumerate(genericOp.getOutputOperands())) {
1007-
Value result = genericOp.getResult(outputOpOperand.index());
1008-
AffineMap indexingMap =
1009-
genericOp.getMatchingIndexingMap(outputOpOperand.value());
1010-
auto key =
1011-
std::make_tuple(outputOpOperand.value()->get(), indexingMap,
1012-
yieldOp->getOperand(outputOpOperand.index()));
1013-
1014-
// Do not drop an out if its value is used in the payload.
1015-
if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
1016-
if (result.use_empty()) {
1017-
// Check if the opoperand can be dropped without affecting loop
1018-
// bound computation. Add the operand to the list of dropped op
1019-
// operand for checking. If it cannot be dropped, need to pop the
1020-
// value back.
1021-
droppedOpOperands.push_back(outputOpOperand.value());
1022-
if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
1023-
continue;
1024-
}
1025-
droppedOpOperands.pop_back();
1026-
}
1027-
1028-
// The out operand can also be dropped if it is computed redundantly
1029-
// by another result, the conditions for that are
1030-
// - The same operand is used as the out operand
1031-
// - The same indexing map is used
1032-
// - The same yield value is used.
1033-
auto it = dedupedOutpts.find(key);
1034-
if (it != dedupedOutpts.end()) {
1035-
origToNewPos[outputOpOperand.index()] = it->second;
1036-
droppedOpOperands.push_back(outputOpOperand.value());
1037-
continue;
1038-
}
1036+
return origToNewPos;
1037+
}
1038+
// Output argument can be dropped if the result has
1039+
// - no users, and
1040+
// - it is not used in the payload, and
1041+
// - the corresponding indexing maps are not needed for loop bound
1042+
// computation.
1043+
auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
1044+
for (const auto &outputOpOperand :
1045+
llvm::enumerate(genericOp.getOutputOperands())) {
1046+
OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
1047+
AffineMap indexingMap =
1048+
genericOp.getMatchingIndexingMap(outputOpOperand.value());
1049+
auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
1050+
yieldOp->getOperand(outputOpOperand.index()));
1051+
assert(genericOp.getNumOutputs() >= outputOpOperand.index() &&
1052+
"Output op idx greater than number of outputs.");
1053+
if (isResultValueDead(genericOp, result)) {
1054+
// Check if the opoperand can be dropped without affecting loop
1055+
// bound computation. Add the operand to the list of dropped op
1056+
// operand for checking. If it cannot be dropped, need to pop the
1057+
// value back.
1058+
droppedOpOperands.push_back(outputOpOperand.value());
1059+
if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
1060+
continue;
10391061
}
1062+
droppedOpOperands.pop_back();
1063+
}
10401064

1041-
origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
1042-
dedupedOutpts[key] = newOutputOperands.size();
1043-
newOutputOperands.push_back(outputOpOperand.value()->get());
1044-
newIndexingMaps.push_back(
1045-
genericOp.getMatchingIndexingMap(outputOpOperand.value()));
1065+
if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
1066+
// The out operand can also be dropped if it is computed redundantly
1067+
// by another result, the conditions for that are
1068+
// - The same operand is used as the out operand
1069+
// - The same indexing map is used
1070+
// - The same yield value is used.
1071+
auto it = dedupedOutpts.find(key);
1072+
if (it != dedupedOutpts.end()) {
1073+
origToNewPos[outputOpOperand.index()] = it->second;
1074+
droppedOpOperands.push_back(outputOpOperand.value());
1075+
continue;
1076+
}
10461077
}
1047-
}
10481078

1079+
origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
1080+
dedupedOutpts[key] = newOutputOperands.size();
1081+
newOutputOperands.push_back(outputOpOperand.value()->get());
1082+
newIndexingMaps.push_back(
1083+
genericOp.getMatchingIndexingMap(outputOpOperand.value()));
1084+
}
10491085
return origToNewPos;
10501086
}
10511087

@@ -1085,12 +1121,10 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
10851121
updateReplacements(origOutputOperands, newOutputOperands,
10861122
origOutsToNewOutsPos);
10871123

1088-
rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
1089-
10901124
// Drop the unused yield args.
10911125
if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
10921126
OpBuilder::InsertionGuard g(rewriter);
1093-
YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
1127+
YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
10941128
rewriter.setInsertionPoint(origYieldOp);
10951129

10961130
SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
@@ -1103,6 +1137,8 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
11031137
}
11041138
rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
11051139
}
1140+
1141+
rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
11061142
}
11071143
};
11081144

@@ -1178,13 +1214,75 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
11781214
return success();
11791215
}
11801216
};
1217+
1218+
/// Remove unused cycles.
1219+
/// We can remove unused cycle within a payload of generic region
1220+
/// if these conditions are met:
1221+
/// - Result from out operand is dead.
1222+
/// - Block arg from out operand has a single use in the %cycle
1223+
/// instruction.
1224+
/// - Cycle has a single use and it is in yield.
1225+
struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
1226+
using OpRewritePattern<GenericOp>::OpRewritePattern;
1227+
1228+
LogicalResult matchAndRewrite(GenericOp genericOp,
1229+
PatternRewriter &rewriter) const override {
1230+
1231+
// If the op doesnt have tensor semantics, preserve the outputs as is.
1232+
if (!genericOp.hasTensorSemantics())
1233+
return failure();
1234+
1235+
bool hasRemovedCycles = false;
1236+
// Iterate over output operands and remove any unused cycles.
1237+
for (const auto &outputOpOperand :
1238+
llvm::enumerate(genericOp.getOutputOperands())) {
1239+
1240+
// Check that result from out operand is dead.
1241+
Value result = genericOp.getResult(outputOpOperand.index());
1242+
if (!result.use_empty())
1243+
continue;
1244+
1245+
// Check that outputArg has one use in cycle.
1246+
BlockArgument outputArg =
1247+
genericOp.getRegionOutputArgs()[outputOpOperand.index()];
1248+
if (!outputArg.hasOneUse())
1249+
continue;
1250+
1251+
// Check cycle has at most one use.
1252+
Operation *cycleOp = *outputArg.user_begin();
1253+
if (!cycleOp->hasOneUse())
1254+
continue;
1255+
1256+
// Check that the cycleUser is a yield.
1257+
Operation *cycleUserOp = *cycleOp->user_begin();
1258+
if (!isa<linalg::YieldOp>(cycleUserOp))
1259+
continue;
1260+
1261+
// Check that argIndex matches yieldIndex, else data is being used.
1262+
if (cycleUserOp->getOperand(outputOpOperand.index()) !=
1263+
cycleOp->getResult(0))
1264+
continue;
1265+
1266+
// Directly replace the cycle with the blockArg such that
1267+
// Deduplicate pattern can eliminate it along with unused yield.
1268+
rewriter.replaceOp(cycleOp, outputArg);
1269+
rewriter.updateRootInPlace(genericOp, [] {});
1270+
hasRemovedCycles = true;
1271+
}
1272+
1273+
if (hasRemovedCycles) {
1274+
return success();
1275+
}
1276+
1277+
return failure();
1278+
}
1279+
};
11811280
} // namespace
11821281

11831282
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
11841283
MLIRContext *context) {
1185-
results
1186-
.add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
1187-
context);
1284+
results.add<DeduplicateAndRemoveDeadOperandsAndResults,
1285+
EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
11881286
}
11891287

11901288
LogicalResult GenericOp::fold(ArrayRef<Attribute>,

0 commit comments

Comments
 (0)