@@ -871,285 +871,10 @@ void GenericOp::getEffects(
871
871
getDpsInputOperands (), getDpsInitOperands ());
872
872
}
873
873
874
- static bool isResultValueDead (linalg::GenericOp genericOp, OpResult result) {
875
- if (!result.use_empty ())
876
- return false ;
877
- // If out operand not used in payload, we can drop it.
878
- OpOperand *outputOpOperand =
879
- genericOp.getDpsInitOperand (result.getResultNumber ());
880
- if (!genericOp.payloadUsesValueFromOperand (outputOpOperand))
881
- return true ;
882
-
883
- // The out operand that is part of a payload can be dropped if
884
- // these conditions are met:
885
- // - Result from out operand is dead.
886
- // - User of arg is yield.
887
- // - outArg data is not being used by other outArgs.
888
-
889
- // Check block arg and cycle from out operand has a single use.
890
- BlockArgument outputArg =
891
- genericOp.getRegionOutputArgs ()[result.getResultNumber ()];
892
- if (!outputArg.hasOneUse ())
893
- return false ;
894
- Operation *argUserOp = *outputArg.user_begin ();
895
-
896
- // Check argUser has no other use.
897
- if (!argUserOp->use_empty ())
898
- return false ;
899
-
900
- // Check that argUser is a yield.
901
- auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
902
- if (!yieldOp)
903
- return false ;
904
-
905
- // Check outArg data is not being used by other outArgs.
906
- if (yieldOp.getOperand (result.getResultNumber ()) != outputArg)
907
- return false ;
908
-
909
- return true ;
910
- }
911
-
912
874
LogicalResult GenericOp::verify () { return success (); }
913
875
914
876
namespace {
915
877
916
- struct DeduplicateAndRemoveDeadOperandsAndResults
917
- : public OpRewritePattern<GenericOp> {
918
- using OpRewritePattern<GenericOp>::OpRewritePattern;
919
-
920
- LogicalResult matchAndRewrite (GenericOp genericOp,
921
- PatternRewriter &rewriter) const override {
922
- // Create a map from argument position in the original op to the argument
923
- // position in the new op. If the argument is dropped it wont have an entry.
924
- SmallVector<OpOperand *> droppedOpOperands;
925
-
926
- // Information needed to build the new op.
927
- SmallVector<Value> newInputOperands, newOutputOperands;
928
- SmallVector<AffineMap> newIndexingMaps;
929
-
930
- // Gather information about duplicate input operands.
931
- llvm::SmallDenseMap<unsigned , unsigned > origInsToNewInsPos =
932
- deduplicateInputOperands (genericOp, droppedOpOperands, newInputOperands,
933
- newIndexingMaps);
934
-
935
- // Gather information about the dropped outputs.
936
- llvm::SmallDenseMap<unsigned , unsigned > origOutsToNewOutsPos =
937
- deduplicateOutputOperands (genericOp, droppedOpOperands,
938
- newOutputOperands, newIndexingMaps);
939
-
940
- // Check if there is any change to operands.
941
- if (newInputOperands.size () + newOutputOperands.size () ==
942
- genericOp->getNumOperands ())
943
- return failure ();
944
-
945
- // Create the new op with the body being empty.
946
- Location loc = genericOp.getLoc ();
947
- SmallVector<Type> newResultTypes;
948
- for (Value v : newOutputOperands)
949
- if (v.getType ().isa <TensorType>())
950
- newResultTypes.push_back (v.getType ());
951
- auto newOp = rewriter.create <GenericOp>(
952
- loc, newResultTypes, newInputOperands, newOutputOperands,
953
- rewriter.getAffineMapArrayAttr (newIndexingMaps),
954
- genericOp.getIteratorTypes (), genericOp.getDocAttr (),
955
- genericOp.getLibraryCallAttr (),
956
- [](OpBuilder & /* builder*/ , Location /* loc*/ , ValueRange /* args*/ ) {
957
- return ;
958
- });
959
- // Copy over unknown attributes. They might be load bearing for some flow.
960
- ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames ();
961
- for (NamedAttribute kv : genericOp->getAttrs ())
962
- if (!llvm::is_contained (odsAttrs, kv.getName ().getValue ()))
963
- newOp->setAttr (kv.getName (), kv.getValue ());
964
-
965
- // Fix up the payload of the canonicalized operation.
966
- populateOpPayload (genericOp, newOp, origInsToNewInsPos,
967
- origOutsToNewOutsPos, rewriter);
968
-
969
- // Replace all live uses of the op.
970
- SmallVector<Value> replacementsVals (genericOp->getNumResults (), nullptr );
971
- for (const auto &result : llvm::enumerate (genericOp.getResults ())) {
972
- auto it = origOutsToNewOutsPos.find (result.index ());
973
- if (it == origOutsToNewOutsPos.end ())
974
- continue ;
975
- replacementsVals[result.index ()] = newOp.getResult (it->second );
976
- }
977
- rewriter.replaceOp (genericOp, replacementsVals);
978
- return success ();
979
- }
980
-
981
- private:
982
- // Deduplicate input operands, and return the
983
- // - Mapping from operand position in the original op, to operand position in
984
- // the canonicalized op.
985
- // - The preserved input operands list (by reference).
986
- llvm::SmallDenseMap<unsigned , unsigned >
987
- deduplicateInputOperands (GenericOp genericOp,
988
- SmallVector<OpOperand *> &droppedOpOperands,
989
- SmallVector<Value> &newInputOperands,
990
- SmallVector<AffineMap> &newIndexingMaps) const {
991
- llvm::SmallDenseMap<unsigned , unsigned > origToNewPos;
992
- llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned > dedupedInputs;
993
- for (const auto &en : llvm::enumerate (genericOp.getDpsInputOperands ())) {
994
- OpOperand *inputOpOperand = en.value ();
995
- // Check if operand is dead and if dropping the indexing map makes the
996
- // loops to shape computation invalid.
997
- if (!genericOp.payloadUsesValueFromOperand (inputOpOperand)) {
998
- // Add the current operands to the list of potentially droppable
999
- // operands. If it cannot be dropped, this needs to be popped back.
1000
- droppedOpOperands.push_back (inputOpOperand);
1001
- if (genericOp.canOpOperandsBeDropped (droppedOpOperands))
1002
- continue ;
1003
- droppedOpOperands.pop_back ();
1004
- }
1005
-
1006
- // Check if this operand is a duplicate.
1007
- AffineMap indexingMap = genericOp.getMatchingIndexingMap (inputOpOperand);
1008
- auto it = dedupedInputs.find (
1009
- std::make_pair (inputOpOperand->get (), indexingMap));
1010
- if (it != dedupedInputs.end ()) {
1011
- origToNewPos[en.index ()] = it->second ;
1012
- droppedOpOperands.push_back (inputOpOperand);
1013
- continue ;
1014
- }
1015
-
1016
- // This is a preserved argument.
1017
- origToNewPos[en.index ()] = newInputOperands.size ();
1018
- dedupedInputs[{inputOpOperand->get (), indexingMap}] =
1019
- newInputOperands.size ();
1020
- newInputOperands.push_back (inputOpOperand->get ());
1021
- newIndexingMaps.push_back (indexingMap);
1022
- }
1023
- return origToNewPos;
1024
- }
1025
-
1026
- // Deduplicate output operands, and return the
1027
- // - Mapping from operand position in the original op, to operand position in
1028
- // the canonicalized op.
1029
- // - The preserved output operands list (by reference).
1030
- llvm::SmallDenseMap<unsigned , unsigned >
1031
- deduplicateOutputOperands (GenericOp genericOp,
1032
- SmallVector<OpOperand *> &droppedOpOperands,
1033
- SmallVector<Value> &newOutputOperands,
1034
- SmallVector<AffineMap> &newIndexingMaps) const {
1035
- llvm::SmallDenseMap<unsigned , unsigned > origToNewPos;
1036
- llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned >
1037
- dedupedOutpts;
1038
- // If the op doesnt have tensor semantics, keep all the outputs as
1039
- // preserved.
1040
- if (!genericOp.hasTensorSemantics ()) {
1041
- for (const auto &en : llvm::enumerate (genericOp.getDpsInitOperands ())) {
1042
- origToNewPos[en.index ()] = newOutputOperands.size ();
1043
- newOutputOperands.push_back (en.value ()->get ());
1044
- newIndexingMaps.push_back (genericOp.getMatchingIndexingMap (en.value ()));
1045
- }
1046
- return origToNewPos;
1047
- }
1048
- // Output argument can be dropped if the result has
1049
- // - no users, and
1050
- // - it is not used in the payload, and
1051
- // - the corresponding indexing maps are not needed for loop bound
1052
- // computation.
1053
- auto yieldOp = cast<YieldOp>(genericOp.getBody ()->getTerminator ());
1054
- for (const auto &outputOpOperand :
1055
- llvm::enumerate (genericOp.getDpsInitOperands ())) {
1056
- OpResult result = genericOp.getTiedOpResult (outputOpOperand.value ());
1057
- AffineMap indexingMap =
1058
- genericOp.getMatchingIndexingMap (outputOpOperand.value ());
1059
- auto key = std::make_tuple (outputOpOperand.value ()->get (), indexingMap,
1060
- yieldOp->getOperand (outputOpOperand.index ()));
1061
- if (isResultValueDead (genericOp, result)) {
1062
- // Check if the opoperand can be dropped without affecting loop
1063
- // bound computation. Add the operand to the list of dropped op
1064
- // operand for checking. If it cannot be dropped, need to pop the
1065
- // value back.
1066
- droppedOpOperands.push_back (outputOpOperand.value ());
1067
- if (genericOp.canOpOperandsBeDropped (droppedOpOperands)) {
1068
- continue ;
1069
- }
1070
- droppedOpOperands.pop_back ();
1071
- }
1072
-
1073
- if (!genericOp.payloadUsesValueFromOperand (outputOpOperand.value ())) {
1074
- // The out operand can also be dropped if it is computed redundantly
1075
- // by another result, the conditions for that are
1076
- // - The same operand is used as the out operand
1077
- // - The same indexing map is used
1078
- // - The same yield value is used.
1079
- auto it = dedupedOutpts.find (key);
1080
- if (it != dedupedOutpts.end ()) {
1081
- origToNewPos[outputOpOperand.index ()] = it->second ;
1082
- droppedOpOperands.push_back (outputOpOperand.value ());
1083
- continue ;
1084
- }
1085
- }
1086
-
1087
- origToNewPos[outputOpOperand.index ()] = newOutputOperands.size ();
1088
- dedupedOutpts[key] = newOutputOperands.size ();
1089
- newOutputOperands.push_back (outputOpOperand.value ()->get ());
1090
- newIndexingMaps.push_back (
1091
- genericOp.getMatchingIndexingMap (outputOpOperand.value ()));
1092
- }
1093
- return origToNewPos;
1094
- }
1095
-
1096
- // Populate the body of the canonicalized operation.
1097
- void populateOpPayload (
1098
- GenericOp genericOp, GenericOp newOp,
1099
- const llvm::SmallDenseMap<unsigned , unsigned > &origInsToNewInsPos,
1100
- const llvm::SmallDenseMap<unsigned , unsigned > &origOutsToNewOutsPos,
1101
- PatternRewriter &rewriter) const {
1102
- // Merge the body of the original op with the new op.
1103
- Block *newOpBlock = &newOp.getRegion ().front ();
1104
- assert (newOpBlock->empty () && " expected new op to have an empty payload" );
1105
- Block *origOpBlock = &genericOp.getRegion ().front ();
1106
- SmallVector<Value> replacements (origOpBlock->getNumArguments (), nullptr );
1107
-
1108
- // Replace all arguments in the original op, with arguments from the
1109
- // canonicalized op.
1110
- auto updateReplacements =
1111
- [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
1112
- const llvm::SmallDenseMap<unsigned , unsigned > &map) {
1113
- for (const auto &origOperand : llvm::enumerate (origOperands)) {
1114
- auto it = map.find (origOperand.index ());
1115
- if (it == map.end ())
1116
- continue ;
1117
- OpOperand *newOperand = newOperands[it->second ];
1118
- replacements[origOperand.value ()->getOperandNumber ()] =
1119
- newOpBlock->getArgument (newOperand->getOperandNumber ());
1120
- }
1121
- };
1122
-
1123
- OpOperandVector origInputOperands = genericOp.getDpsInputOperands ();
1124
- OpOperandVector newInputOperands = newOp.getDpsInputOperands ();
1125
- updateReplacements (origInputOperands, newInputOperands, origInsToNewInsPos);
1126
-
1127
- OpOperandVector origOutputOperands = genericOp.getDpsInitOperands ();
1128
- OpOperandVector newOutputOperands = newOp.getDpsInitOperands ();
1129
- updateReplacements (origOutputOperands, newOutputOperands,
1130
- origOutsToNewOutsPos);
1131
-
1132
- // Drop the unused yield args.
1133
- if (newOp.getNumDpsInits () != genericOp.getNumDpsInits ()) {
1134
- OpBuilder::InsertionGuard g (rewriter);
1135
- YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator ());
1136
- rewriter.setInsertionPoint (origYieldOp);
1137
-
1138
- SmallVector<Value> newYieldVals (newOp.getNumDpsInits (), nullptr );
1139
- for (const auto &yieldOpOperands :
1140
- llvm::enumerate (origYieldOp.getValues ())) {
1141
- auto it = origOutsToNewOutsPos.find (yieldOpOperands.index ());
1142
- if (it == origOutsToNewOutsPos.end ())
1143
- continue ;
1144
- newYieldVals[it->second ] = yieldOpOperands.value ();
1145
- }
1146
- rewriter.replaceOpWithNewOp <YieldOp>(origYieldOp, newYieldVals);
1147
- }
1148
-
1149
- rewriter.mergeBlocks (origOpBlock, newOpBlock, replacements);
1150
- }
1151
- };
1152
-
1153
878
// / Remove generic operations (on tensors) that are just copying
1154
879
// / the values from inputs to the results. Requirements are
1155
880
// / 1) All iterator types are parallel
@@ -1227,74 +952,11 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
1227
952
}
1228
953
};
1229
954
1230
- // / Remove unused cycles.
1231
- // / We can remove unused cycle within a payload of generic region
1232
- // / if these conditions are met:
1233
- // / - Result from out operand is dead.
1234
- // / - Block arg from out operand has a single use in the %cycle
1235
- // / instruction.
1236
- // / - Cycle has a single use and it is in yield.
1237
- struct RemoveUnusedCycleInGenericOp : public OpRewritePattern <GenericOp> {
1238
- using OpRewritePattern<GenericOp>::OpRewritePattern;
1239
-
1240
- LogicalResult matchAndRewrite (GenericOp genericOp,
1241
- PatternRewriter &rewriter) const override {
1242
-
1243
- // If the op doesnt have tensor semantics, preserve the outputs as is.
1244
- if (!genericOp.hasTensorSemantics ())
1245
- return failure ();
1246
-
1247
- bool hasRemovedCycles = false ;
1248
- // Iterate over output operands and remove any unused cycles.
1249
- for (const auto &outputOpOperand :
1250
- llvm::enumerate (genericOp.getDpsInitOperands ())) {
1251
-
1252
- // Check that result from out operand is dead.
1253
- Value result = genericOp.getResult (outputOpOperand.index ());
1254
- if (!result.use_empty ())
1255
- continue ;
1256
-
1257
- // Check that outputArg has one use in cycle.
1258
- BlockArgument outputArg =
1259
- genericOp.getRegionOutputArgs ()[outputOpOperand.index ()];
1260
- if (!outputArg.hasOneUse ())
1261
- continue ;
1262
-
1263
- // Check cycle has at most one use.
1264
- Operation *cycleOp = *outputArg.user_begin ();
1265
- if (!cycleOp->hasOneUse ())
1266
- continue ;
1267
-
1268
- // Check that the cycleUser is a yield.
1269
- Operation *cycleUserOp = *cycleOp->user_begin ();
1270
- if (!isa<linalg::YieldOp>(cycleUserOp))
1271
- continue ;
1272
-
1273
- // Check that argIndex matches yieldIndex, else data is being used.
1274
- if (cycleUserOp->getOperand (outputOpOperand.index ()) !=
1275
- cycleOp->getResult (0 ))
1276
- continue ;
1277
-
1278
- // Directly replace the cycle with the blockArg such that
1279
- // Deduplicate pattern can eliminate it along with unused yield.
1280
- rewriter.replaceOp (cycleOp, outputArg);
1281
- rewriter.updateRootInPlace (genericOp, [] {});
1282
- hasRemovedCycles = true ;
1283
- }
1284
-
1285
- if (hasRemovedCycles) {
1286
- return success ();
1287
- }
1288
-
1289
- return failure ();
1290
- }
1291
- };
1292
955
} // namespace
1293
956
1294
957
void GenericOp::getCanonicalizationPatterns (RewritePatternSet &results,
1295
958
MLIRContext *context) {
1296
- results.add <DeduplicateAndRemoveDeadOperandsAndResults,
1297
- EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
959
+ results.add <EraseIdentityGenericOp>(context);
1298
960
}
1299
961
1300
962
LogicalResult GenericOp::fold (ArrayRef<Attribute>,
0 commit comments