19
19
#include " mlir/IR/BuiltinAttributes.h"
20
20
#include " mlir/IR/Diagnostics.h"
21
21
#include " mlir/IR/Dominance.h"
22
+ #include " mlir/IR/OpImplementation.h"
22
23
#include " mlir/IR/OperationSupport.h"
23
24
#include " mlir/IR/PatternMatch.h"
24
25
#include " mlir/IR/Verifier.h"
@@ -834,19 +835,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
834
835
// CollectMatchingOp
835
836
// ===----------------------------------------------------------------------===//
836
837
837
- // / Applies matcher operations from the given `block` assigning `op` as the
838
- // / payload of the block's first argument. Updates `state` accordingly. If any
839
- // / of the matcher produces a silenceable failure, discards it (printing the
840
- // / content to the debug output stream) and returns failure. If any of the
841
- // / matchers produces a definite failure, reports it and returns failure. If all
842
- // / matchers in the block succeed, populates `mappings` with the payload
843
- // / entities associated with the block terminator operands.
838
+ // / Applies matcher operations from the given `block` using
839
+ // / `blockArgumentMapping` to initialize block arguments. Updates `state`
840
+ // / accordingly. If any of the matcher produces a silenceable failure, discards
841
+ // / it (printing the content to the debug output stream) and returns failure. If
842
+ // / any of the matchers produces a definite failure, reports it and returns
843
+ // / failure. If all matchers in the block succeed, populates `mappings` with the
844
+ // / payload entities associated with the block terminator operands. Note that
845
+ // / `mappings` will be cleared before that.
844
846
static DiagnosedSilenceableFailure
845
- matchBlock (Block &block, Operation *op, transform::TransformState &state,
847
+ matchBlock (Block &block,
848
+ ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
849
+ transform::TransformState &state,
846
850
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847
851
assert (block.getParent () && " cannot match using a detached block" );
848
852
auto matchScope = state.make_region_scope (*block.getParent ());
849
- if (failed (state.mapBlockArgument (block.getArgument (0 ), {op})))
853
+ if (failed (
854
+ state.mapBlockArguments (block.getArguments (), blockArgumentMapping)))
850
855
return DiagnosedSilenceableFailure::definiteFailure ();
851
856
852
857
for (Operation &match : block.without_terminator ()) {
@@ -866,6 +871,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
866
871
// Remember the values mapped to the terminator operands so we can
867
872
// forward them to the action.
868
873
ValueRange yieldedValues = block.getTerminator ()->getOperands ();
874
+ // Our contract with the caller is that the mappings will contain only the
875
+ // newly mapped values, clear the rest.
876
+ mappings.clear ();
869
877
transform::detail::prepareValueMappings (mappings, yieldedValues, state);
870
878
return DiagnosedSilenceableFailure::success ();
871
879
}
@@ -915,8 +923,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
915
923
916
924
// Try matching.
917
925
SmallVector<SmallVector<MappedValue>> mappings;
918
- DiagnosedSilenceableFailure diag =
919
- matchBlock (matcher.getFunctionBody ().front (), op, state, mappings);
926
+ SmallVector<transform::MappedValue> inputMapping ({op});
927
+ DiagnosedSilenceableFailure diag = matchBlock (
928
+ matcher.getFunctionBody ().front (),
929
+ ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
930
+ mappings);
920
931
if (diag.isDefiniteFailure ())
921
932
return WalkResult::interrupt ();
922
933
if (diag.isSilenceableFailure ()) {
@@ -1001,6 +1012,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1001
1012
// ForeachMatchOp
1002
1013
// ===----------------------------------------------------------------------===//
1003
1014
1015
+ // This is fine because nothing is actually consumed by this op.
1016
+ bool transform::ForeachMatchOp::allowsRepeatedHandleOperands () { return true ; }
1017
+
1004
1018
DiagnosedSilenceableFailure
1005
1019
transform::ForeachMatchOp::apply (transform::TransformRewriter &rewriter,
1006
1020
transform::TransformResults &results,
@@ -1030,6 +1044,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1030
1044
1031
1045
DiagnosedSilenceableFailure overallDiag =
1032
1046
DiagnosedSilenceableFailure::success ();
1047
+
1048
+ SmallVector<SmallVector<MappedValue>> matchInputMapping;
1049
+ SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1050
+ SmallVector<SmallVector<MappedValue>> actionResultMapping;
1051
+ // Explicitly add the mapping for the first block argument (the op being
1052
+ // matched).
1053
+ matchInputMapping.emplace_back ();
1054
+ transform::detail::prepareValueMappings (matchInputMapping,
1055
+ getForwardedInputs (), state);
1056
+ SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front ();
1057
+ actionResultMapping.resize (getForwardedOutputs ().size ());
1058
+
1033
1059
for (Operation *root : state.getPayloadOps (getRoot ())) {
1034
1060
WalkResult walkResult = root->walk ([&](Operation *op) {
1035
1061
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1070,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1044
1070
llvm::dbgs () << " @" << op << " \n " ;
1045
1071
});
1046
1072
1073
+ firstMatchArgument.clear ();
1074
+ firstMatchArgument.push_back (op);
1075
+
1047
1076
// Try all the match/action pairs until the first successful match.
1048
1077
for (auto [matcher, action] : matchActionPairs) {
1049
- SmallVector<SmallVector<MappedValue>> mappings;
1050
1078
DiagnosedSilenceableFailure diag =
1051
- matchBlock (matcher.getFunctionBody ().front (), op, state, mappings);
1079
+ matchBlock (matcher.getFunctionBody ().front (), matchInputMapping,
1080
+ state, matchOutputMapping);
1052
1081
if (diag.isDefiniteFailure ())
1053
1082
return WalkResult::interrupt ();
1054
1083
if (diag.isSilenceableFailure ()) {
@@ -1058,10 +1087,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1058
1087
}
1059
1088
1060
1089
auto scope = state.make_region_scope (action.getFunctionBody ());
1061
- for ( auto &&[arg, map] : llvm::zip_equal (
1062
- action.getFunctionBody ().front ().getArguments (), mappings)) {
1063
- if ( failed (state. mapBlockArgument (arg, map )))
1064
- return WalkResult::interrupt ();
1090
+ if ( failed (state. mapBlockArguments (
1091
+ action.getFunctionBody ().front ().getArguments (),
1092
+ matchOutputMapping ))) {
1093
+ return WalkResult::interrupt ();
1065
1094
}
1066
1095
1067
1096
for (Operation &transform :
@@ -1082,6 +1111,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1082
1111
continue ;
1083
1112
}
1084
1113
}
1114
+ if (failed (detail::appendValueMappings (
1115
+ MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1116
+ action.getFunctionBody ().front ().getTerminator ()->getOperands (),
1117
+ state, getFlattenResults ()))) {
1118
+ emitDefiniteFailure ()
1119
+ << " action @" << action.getName ()
1120
+ << " has results associated with multiple payload entities, "
1121
+ " but flattening was not requested" ;
1122
+ return WalkResult::interrupt ();
1123
+ }
1085
1124
break ;
1086
1125
}
1087
1126
return WalkResult::advance ();
@@ -1096,9 +1135,21 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1096
1135
// by actions, are invalidated.
1097
1136
results.set (llvm::cast<OpResult>(getUpdated ()),
1098
1137
state.getPayloadOps (getRoot ()));
1138
+ for (auto &&[result, mapping] :
1139
+ llvm::zip_equal (getForwardedOutputs (), actionResultMapping)) {
1140
+ results.setMappedValues (result, mapping);
1141
+ }
1099
1142
return overallDiag;
1100
1143
}
1101
1144
1145
+ void transform::ForeachMatchOp::getAsmResultNames (
1146
+ OpAsmSetValueNameFn setNameFn) {
1147
+ setNameFn (getUpdated (), " updated_root" );
1148
+ for (Value v : getForwardedOutputs ()) {
1149
+ setNameFn (v, " yielded" );
1150
+ }
1151
+ }
1152
+
1102
1153
void transform::ForeachMatchOp::getEffects (
1103
1154
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1104
1155
// Bail if invalid.
@@ -1108,7 +1159,8 @@ void transform::ForeachMatchOp::getEffects(
1108
1159
}
1109
1160
1110
1161
consumesHandle (getRoot (), effects);
1111
- producesHandle (getUpdated (), effects);
1162
+ onlyReadsHandle (getForwardedInputs (), effects);
1163
+ producesHandle (getResults (), effects);
1112
1164
modifiesPayload (effects);
1113
1165
}
1114
1166
@@ -1224,6 +1276,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1224
1276
StringAttr::get (getContext (), TransformDialect::kArgConsumedAttrName );
1225
1277
for (auto &&[matcher, action] :
1226
1278
llvm::zip_equal (getMatchers (), getActions ())) {
1279
+ // Presence and typing.
1227
1280
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1228
1281
symbolTable.lookupNearestSymbolFrom (getOperation (),
1229
1282
cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1303,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1250
1303
return failure ();
1251
1304
}
1252
1305
1253
- ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes ();
1254
- ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes ();
1306
+ // Input -> matcher forwarding.
1307
+ TypeRange operandTypes = getOperandTypes ();
1308
+ TypeRange matcherArguments = matcherSymbol.getArgumentTypes ();
1309
+ if (operandTypes.size () != matcherArguments.size ()) {
1310
+ InFlightDiagnostic diag =
1311
+ emitError () << " the number of operands (" << operandTypes.size ()
1312
+ << " ) doesn't match the number of matcher arguments ("
1313
+ << matcherArguments.size () << " ) for " << matcher;
1314
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1315
+ return diag;
1316
+ }
1317
+ for (auto &&[i, operand, argument] :
1318
+ llvm::enumerate (operandTypes, matcherArguments)) {
1319
+ if (matcherSymbol.getArgAttr (i, consumedAttr)) {
1320
+ InFlightDiagnostic diag =
1321
+ emitOpError ()
1322
+ << " does not expect matcher symbol to consume its operand #" << i;
1323
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1324
+ return diag;
1325
+ }
1326
+
1327
+ if (implementSameTransformInterface (operand, argument))
1328
+ continue ;
1329
+
1330
+ InFlightDiagnostic diag =
1331
+ emitError ()
1332
+ << " mismatching type interfaces for operand and matcher argument #"
1333
+ << i << " of matcher " << matcher;
1334
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1335
+ return diag;
1336
+ }
1337
+
1338
+ // Matcher -> action forwarding.
1339
+ TypeRange matcherResults = matcherSymbol.getResultTypes ();
1340
+ TypeRange actionArguments = actionSymbol.getArgumentTypes ();
1255
1341
if (matcherResults.size () != actionArguments.size ()) {
1256
1342
return emitError () << " mismatching number of matcher results and "
1257
1343
" action arguments between "
@@ -1265,31 +1351,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1265
1351
1266
1352
return emitError () << " mismatching type interfaces for matcher result "
1267
1353
" and action argument #"
1268
- << i;
1354
+ << i << " of matcher " << matcher << " and action "
1355
+ << action;
1269
1356
}
1270
1357
1271
- if (!actionSymbol.getResultTypes ().empty ()) {
1358
+ // Action -> result forwarding.
1359
+ TypeRange actionResults = actionSymbol.getResultTypes ();
1360
+ auto resultTypes = TypeRange (getResultTypes ()).drop_front ();
1361
+ if (actionResults.size () != resultTypes.size ()) {
1272
1362
InFlightDiagnostic diag =
1273
- emitError () << " action symbol is not expected to have results" ;
1363
+ emitError () << " the number of action results ("
1364
+ << actionResults.size () << " ) for " << action
1365
+ << " doesn't match the number of extra op results ("
1366
+ << resultTypes.size () << " )" ;
1274
1367
diag.attachNote (actionSymbol->getLoc ()) << " symbol declaration" ;
1275
1368
return diag;
1276
1369
}
1370
+ for (auto &&[i, resultType, actionType] :
1371
+ llvm::enumerate (resultTypes, actionResults)) {
1372
+ if (implementSameTransformInterface (resultType, actionType))
1373
+ continue ;
1277
1374
1278
- if (matcherSymbol.getArgumentTypes ().size () != 1 ||
1279
- !implementSameTransformInterface (matcherSymbol.getArgumentTypes ()[0 ],
1280
- getRoot ().getType ())) {
1281
- InFlightDiagnostic diag =
1282
- emitOpError () << " expects matcher symbol to have one argument with "
1283
- " the same transform interface as the first operand" ;
1284
- diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1285
- return diag;
1286
- }
1287
-
1288
- if (matcherSymbol.getArgAttr (0 , consumedAttr)) {
1289
1375
InFlightDiagnostic diag =
1290
- emitOpError ()
1291
- << " does not expect matcher symbol to consume its operand " ;
1292
- diag.attachNote (matcherSymbol ->getLoc ()) << " symbol declaration" ;
1376
+ emitError () << " mismatching type interfaces for action result # " << i
1377
+ << " of action " << action << " and op result " ;
1378
+ diag.attachNote (actionSymbol ->getLoc ()) << " symbol declaration" ;
1293
1379
return diag;
1294
1380
}
1295
1381
}
0 commit comments