@@ -857,6 +857,44 @@ void GenericOp::getEffects(
857
857
outputBuffers);
858
858
}
859
859
860
+ static bool isResultValueDead (linalg::GenericOp genericOp, OpResult result) {
861
+ if (!result.use_empty ())
862
+ return false ;
863
+ // If out operand not used in payload, we can drop it.
864
+ OpOperand *outputOpOperand =
865
+ genericOp.getOutputOperand (result.getResultNumber ());
866
+ if (!genericOp.payloadUsesValueFromOperand (outputOpOperand))
867
+ return true ;
868
+
869
+ // The out operand that is part of a payload can be dropped if
870
+ // these conditions are met:
871
+ // - Result from out operand is dead.
872
+ // - User of arg is yield.
873
+ // - outArg data is not being used by other outArgs.
874
+
875
+ // Check block arg and cycle from out operand has a single use.
876
+ BlockArgument outputArg =
877
+ genericOp.getRegionOutputArgs ()[result.getResultNumber ()];
878
+ if (!outputArg.hasOneUse ())
879
+ return false ;
880
+ Operation *argUserOp = *outputArg.user_begin ();
881
+
882
+ // Check argUser has no other use.
883
+ if (!argUserOp->use_empty ())
884
+ return false ;
885
+
886
+ // Check that argUser is a yield.
887
+ auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
888
+ if (!yieldOp)
889
+ return false ;
890
+
891
+ // Check outArg data is not being used by other outArgs.
892
+ if (yieldOp.getOperand (result.getResultNumber ()) != outputArg)
893
+ return false ;
894
+
895
+ return true ;
896
+ }
897
+
860
898
LogicalResult GenericOp::verify () { return success (); }
861
899
862
900
namespace {
@@ -995,57 +1033,55 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
995
1033
newIndexingMaps.push_back (
996
1034
genericOp.getMatchingIndexingMap (outputOpOperand.value ()));
997
1035
}
998
- } else {
999
- // Output argument can be dropped if the result has
1000
- // - no users, and
1001
- // - it is not used in the payload, and
1002
- // - the corresponding indexing maps are not needed for loop bound
1003
- // computation.
1004
- auto yieldOp = cast<YieldOp>(genericOp.getBody ()->getTerminator ());
1005
- for (const auto &outputOpOperand :
1006
- llvm::enumerate (genericOp.getOutputOperands ())) {
1007
- Value result = genericOp.getResult (outputOpOperand.index ());
1008
- AffineMap indexingMap =
1009
- genericOp.getMatchingIndexingMap (outputOpOperand.value ());
1010
- auto key =
1011
- std::make_tuple (outputOpOperand.value ()->get (), indexingMap,
1012
- yieldOp->getOperand (outputOpOperand.index ()));
1013
-
1014
- // Do not drop an out if its value is used in the payload.
1015
- if (!genericOp.payloadUsesValueFromOperand (outputOpOperand.value ())) {
1016
- if (result.use_empty ()) {
1017
- // Check if the opoperand can be dropped without affecting loop
1018
- // bound computation. Add the operand to the list of dropped op
1019
- // operand for checking. If it cannot be dropped, need to pop the
1020
- // value back.
1021
- droppedOpOperands.push_back (outputOpOperand.value ());
1022
- if (genericOp.canOpOperandsBeDropped (droppedOpOperands)) {
1023
- continue ;
1024
- }
1025
- droppedOpOperands.pop_back ();
1026
- }
1027
-
1028
- // The out operand can also be dropped if it is computed redundantly
1029
- // by another result, the conditions for that are
1030
- // - The same operand is used as the out operand
1031
- // - The same indexing map is used
1032
- // - The same yield value is used.
1033
- auto it = dedupedOutpts.find (key);
1034
- if (it != dedupedOutpts.end ()) {
1035
- origToNewPos[outputOpOperand.index ()] = it->second ;
1036
- droppedOpOperands.push_back (outputOpOperand.value ());
1037
- continue ;
1038
- }
1036
+ return origToNewPos;
1037
+ }
1038
+ // Output argument can be dropped if the result has
1039
+ // - no users, and
1040
+ // - it is not used in the payload, and
1041
+ // - the corresponding indexing maps are not needed for loop bound
1042
+ // computation.
1043
+ auto yieldOp = cast<YieldOp>(genericOp.getBody ()->getTerminator ());
1044
+ for (const auto &outputOpOperand :
1045
+ llvm::enumerate (genericOp.getOutputOperands ())) {
1046
+ OpResult result = genericOp.getTiedOpResult (outputOpOperand.value ());
1047
+ AffineMap indexingMap =
1048
+ genericOp.getMatchingIndexingMap (outputOpOperand.value ());
1049
+ auto key = std::make_tuple (outputOpOperand.value ()->get (), indexingMap,
1050
+ yieldOp->getOperand (outputOpOperand.index ()));
1051
+ assert (genericOp.getNumOutputs () >= outputOpOperand.index () &&
1052
+ " Output op idx greater than number of outputs." );
1053
+ if (isResultValueDead (genericOp, result)) {
1054
+ // Check if the opoperand can be dropped without affecting loop
1055
+ // bound computation. Add the operand to the list of dropped op
1056
+ // operand for checking. If it cannot be dropped, need to pop the
1057
+ // value back.
1058
+ droppedOpOperands.push_back (outputOpOperand.value ());
1059
+ if (genericOp.canOpOperandsBeDropped (droppedOpOperands)) {
1060
+ continue ;
1039
1061
}
1062
+ droppedOpOperands.pop_back ();
1063
+ }
1040
1064
1041
- origToNewPos[outputOpOperand.index ()] = newOutputOperands.size ();
1042
- dedupedOutpts[key] = newOutputOperands.size ();
1043
- newOutputOperands.push_back (outputOpOperand.value ()->get ());
1044
- newIndexingMaps.push_back (
1045
- genericOp.getMatchingIndexingMap (outputOpOperand.value ()));
1065
+ if (!genericOp.payloadUsesValueFromOperand (outputOpOperand.value ())) {
1066
+ // The out operand can also be dropped if it is computed redundantly
1067
+ // by another result, the conditions for that are
1068
+ // - The same operand is used as the out operand
1069
+ // - The same indexing map is used
1070
+ // - The same yield value is used.
1071
+ auto it = dedupedOutpts.find (key);
1072
+ if (it != dedupedOutpts.end ()) {
1073
+ origToNewPos[outputOpOperand.index ()] = it->second ;
1074
+ droppedOpOperands.push_back (outputOpOperand.value ());
1075
+ continue ;
1076
+ }
1046
1077
}
1047
- }
1048
1078
1079
+ origToNewPos[outputOpOperand.index ()] = newOutputOperands.size ();
1080
+ dedupedOutpts[key] = newOutputOperands.size ();
1081
+ newOutputOperands.push_back (outputOpOperand.value ()->get ());
1082
+ newIndexingMaps.push_back (
1083
+ genericOp.getMatchingIndexingMap (outputOpOperand.value ()));
1084
+ }
1049
1085
return origToNewPos;
1050
1086
}
1051
1087
@@ -1085,12 +1121,10 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
1085
1121
updateReplacements (origOutputOperands, newOutputOperands,
1086
1122
origOutsToNewOutsPos);
1087
1123
1088
- rewriter.mergeBlocks (origOpBlock, newOpBlock, replacements);
1089
-
1090
1124
// Drop the unused yield args.
1091
1125
if (newOp.getNumOutputs () != genericOp.getNumOutputs ()) {
1092
1126
OpBuilder::InsertionGuard g (rewriter);
1093
- YieldOp origYieldOp = cast<YieldOp>(newOpBlock ->getTerminator ());
1127
+ YieldOp origYieldOp = cast<YieldOp>(origOpBlock ->getTerminator ());
1094
1128
rewriter.setInsertionPoint (origYieldOp);
1095
1129
1096
1130
SmallVector<Value> newYieldVals (newOp.getNumOutputs (), nullptr );
@@ -1103,6 +1137,8 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
1103
1137
}
1104
1138
rewriter.replaceOpWithNewOp <YieldOp>(origYieldOp, newYieldVals);
1105
1139
}
1140
+
1141
+ rewriter.mergeBlocks (origOpBlock, newOpBlock, replacements);
1106
1142
}
1107
1143
};
1108
1144
@@ -1178,13 +1214,75 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
1178
1214
return success ();
1179
1215
}
1180
1216
};
1217
+
1218
+ // / Remove unused cycles.
1219
+ // / We can remove unused cycle within a payload of generic region
1220
+ // / if these conditions are met:
1221
+ // / - Result from out operand is dead.
1222
+ // / - Block arg from out operand has a single use in the %cycle
1223
+ // / instruction.
1224
+ // / - Cycle has a single use and it is in yield.
1225
+ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern <GenericOp> {
1226
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
1227
+
1228
+ LogicalResult matchAndRewrite (GenericOp genericOp,
1229
+ PatternRewriter &rewriter) const override {
1230
+
1231
+ // If the op doesnt have tensor semantics, preserve the outputs as is.
1232
+ if (!genericOp.hasTensorSemantics ())
1233
+ return failure ();
1234
+
1235
+ bool hasRemovedCycles = false ;
1236
+ // Iterate over output operands and remove any unused cycles.
1237
+ for (const auto &outputOpOperand :
1238
+ llvm::enumerate (genericOp.getOutputOperands ())) {
1239
+
1240
+ // Check that result from out operand is dead.
1241
+ Value result = genericOp.getResult (outputOpOperand.index ());
1242
+ if (!result.use_empty ())
1243
+ continue ;
1244
+
1245
+ // Check that outputArg has one use in cycle.
1246
+ BlockArgument outputArg =
1247
+ genericOp.getRegionOutputArgs ()[outputOpOperand.index ()];
1248
+ if (!outputArg.hasOneUse ())
1249
+ continue ;
1250
+
1251
+ // Check cycle has at most one use.
1252
+ Operation *cycleOp = *outputArg.user_begin ();
1253
+ if (!cycleOp->hasOneUse ())
1254
+ continue ;
1255
+
1256
+ // Check that the cycleUser is a yield.
1257
+ Operation *cycleUserOp = *cycleOp->user_begin ();
1258
+ if (!isa<linalg::YieldOp>(cycleUserOp))
1259
+ continue ;
1260
+
1261
+ // Check that argIndex matches yieldIndex, else data is being used.
1262
+ if (cycleUserOp->getOperand (outputOpOperand.index ()) !=
1263
+ cycleOp->getResult (0 ))
1264
+ continue ;
1265
+
1266
+ // Directly replace the cycle with the blockArg such that
1267
+ // Deduplicate pattern can eliminate it along with unused yield.
1268
+ rewriter.replaceOp (cycleOp, outputArg);
1269
+ rewriter.updateRootInPlace (genericOp, [] {});
1270
+ hasRemovedCycles = true ;
1271
+ }
1272
+
1273
+ if (hasRemovedCycles) {
1274
+ return success ();
1275
+ }
1276
+
1277
+ return failure ();
1278
+ }
1279
+ };
1181
1280
} // namespace
1182
1281
1183
1282
void GenericOp::getCanonicalizationPatterns (RewritePatternSet &results,
1184
1283
MLIRContext *context) {
1185
- results
1186
- .add <DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
1187
- context);
1284
+ results.add <DeduplicateAndRemoveDeadOperandsAndResults,
1285
+ EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
1188
1286
}
1189
1287
1190
1288
LogicalResult GenericOp::fold (ArrayRef<Attribute>,
0 commit comments