@@ -834,19 +834,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
834
834
// CollectMatchingOp
835
835
// ===----------------------------------------------------------------------===//
836
836
837
- // / Applies matcher operations from the given `block` assigning `op` as the
838
- // / payload of the block's first argument. Updates `state` accordingly. If any
839
- // / of the matcher produces a silenceable failure, discards it (printing the
840
- // / content to the debug output stream) and returns failure. If any of the
841
- // / matchers produces a definite failure, reports it and returns failure. If all
842
- // / matchers in the block succeed, populates `mappings` with the payload
843
- // / entities associated with the block terminator operands.
837
+ // / Applies matcher operations from the given `block` using
838
+ // / `blockArgumentMapping` to initialize block arguments. Updates `state`
839
+ // / accordingly. If any of the matcher produces a silenceable failure, discards
840
+ // / it (printing the content to the debug output stream) and returns failure. If
841
+ // / any of the matchers produces a definite failure, reports it and returns
842
+ // / failure. If all matchers in the block succeed, populates `mappings` with the
843
+ // / payload entities associated with the block terminator operands. Note that
844
+ // / `mappings` will be cleared before that.
844
845
static DiagnosedSilenceableFailure
845
- matchBlock (Block &block, Operation *op, transform::TransformState &state,
846
+ matchBlock (Block &block,
847
+ ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
848
+ transform::TransformState &state,
846
849
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847
850
assert (block.getParent () && " cannot match using a detached block" );
848
851
auto matchScope = state.make_region_scope (*block.getParent ());
849
- if (failed (state.mapBlockArgument (block.getArgument (0 ), {op})))
852
+ if (failed (
853
+ state.mapBlockArguments (block.getArguments (), blockArgumentMapping)))
850
854
return DiagnosedSilenceableFailure::definiteFailure ();
851
855
852
856
for (Operation &match : block.without_terminator ()) {
@@ -866,6 +870,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
866
870
// Remember the values mapped to the terminator operands so we can
867
871
// forward them to the action.
868
872
ValueRange yieldedValues = block.getTerminator ()->getOperands ();
873
+ // Our contract with the caller is that the mappings will contain only the
874
+ // newly mapped values, clear the rest.
875
+ mappings.clear ();
869
876
transform::detail::prepareValueMappings (mappings, yieldedValues, state);
870
877
return DiagnosedSilenceableFailure::success ();
871
878
}
@@ -915,8 +922,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
915
922
916
923
// Try matching.
917
924
SmallVector<SmallVector<MappedValue>> mappings;
918
- DiagnosedSilenceableFailure diag =
919
- matchBlock (matcher.getFunctionBody ().front (), op, state, mappings);
925
+ SmallVector<transform::MappedValue> inputMapping ({op});
926
+ DiagnosedSilenceableFailure diag = matchBlock (
927
+ matcher.getFunctionBody ().front (),
928
+ ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
929
+ mappings);
920
930
if (diag.isDefiniteFailure ())
921
931
return WalkResult::interrupt ();
922
932
if (diag.isSilenceableFailure ()) {
@@ -1001,6 +1011,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1001
1011
// ForeachMatchOp
1002
1012
// ===----------------------------------------------------------------------===//
1003
1013
1014
+ // This is fine because nothing is actually consumed by this op.
1015
+ bool transform::ForeachMatchOp::allowsRepeatedHandleOperands () { return true ; }
1016
+
1004
1017
DiagnosedSilenceableFailure
1005
1018
transform::ForeachMatchOp::apply (transform::TransformRewriter &rewriter,
1006
1019
transform::TransformResults &results,
@@ -1030,6 +1043,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1030
1043
1031
1044
DiagnosedSilenceableFailure overallDiag =
1032
1045
DiagnosedSilenceableFailure::success ();
1046
+
1047
+ SmallVector<SmallVector<MappedValue>> matchInputMapping;
1048
+ SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1049
+ SmallVector<SmallVector<MappedValue>> actionResultMapping;
1050
+ // Explicitly add the mapping for the first block argument (the op being
1051
+ // matched).
1052
+ matchInputMapping.emplace_back ();
1053
+ transform::detail::prepareValueMappings (matchInputMapping,
1054
+ getForwardedInputs (), state);
1055
+ SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front ();
1056
+ actionResultMapping.resize (getForwardedOutputs ().size ());
1057
+
1033
1058
for (Operation *root : state.getPayloadOps (getRoot ())) {
1034
1059
WalkResult walkResult = root->walk ([&](Operation *op) {
1035
1060
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1069,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1044
1069
llvm::dbgs () << " @" << op << " \n " ;
1045
1070
});
1046
1071
1072
+ firstMatchArgument.clear ();
1073
+ firstMatchArgument.push_back (op);
1074
+
1047
1075
// Try all the match/action pairs until the first successful match.
1048
1076
for (auto [matcher, action] : matchActionPairs) {
1049
- SmallVector<SmallVector<MappedValue>> mappings;
1050
1077
DiagnosedSilenceableFailure diag =
1051
- matchBlock (matcher.getFunctionBody ().front (), op, state, mappings);
1078
+ matchBlock (matcher.getFunctionBody ().front (), matchInputMapping,
1079
+ state, matchOutputMapping);
1052
1080
if (diag.isDefiniteFailure ())
1053
1081
return WalkResult::interrupt ();
1054
1082
if (diag.isSilenceableFailure ()) {
@@ -1058,10 +1086,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1058
1086
}
1059
1087
1060
1088
auto scope = state.make_region_scope (action.getFunctionBody ());
1061
- for ( auto &&[arg, map] : llvm::zip_equal (
1062
- action.getFunctionBody ().front ().getArguments (), mappings)) {
1063
- if ( failed (state. mapBlockArgument (arg, map )))
1064
- return WalkResult::interrupt ();
1089
+ if ( failed (state. mapBlockArguments (
1090
+ action.getFunctionBody ().front ().getArguments (),
1091
+ matchOutputMapping ))) {
1092
+ return WalkResult::interrupt ();
1065
1093
}
1066
1094
1067
1095
for (Operation &transform :
@@ -1082,6 +1110,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1082
1110
continue ;
1083
1111
}
1084
1112
}
1113
+ if (failed (detail::appendValueMappings (
1114
+ MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1115
+ action.getFunctionBody ().front ().getTerminator ()->getOperands (),
1116
+ state, getFlattenResults ()))) {
1117
+ emitDefiniteFailure ()
1118
+ << " action @" << action.getName ()
1119
+ << " has results associated with multiple payload entities, "
1120
+ " but flattening was not requested" ;
1121
+ return WalkResult::interrupt ();
1122
+ }
1085
1123
break ;
1086
1124
}
1087
1125
return WalkResult::advance ();
@@ -1096,6 +1134,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1096
1134
// by actions, are invalidated.
1097
1135
results.set (llvm::cast<OpResult>(getUpdated ()),
1098
1136
state.getPayloadOps (getRoot ()));
1137
+ for (auto &&[result, mapping] :
1138
+ llvm::zip_equal (getForwardedOutputs (), actionResultMapping)) {
1139
+ results.setMappedValues (result, mapping);
1140
+ }
1099
1141
return overallDiag;
1100
1142
}
1101
1143
@@ -1108,7 +1150,8 @@ void transform::ForeachMatchOp::getEffects(
1108
1150
}
1109
1151
1110
1152
consumesHandle (getRoot (), effects);
1111
- producesHandle (getUpdated (), effects);
1153
+ onlyReadsHandle (getForwardedInputs (), effects);
1154
+ producesHandle (getResults (), effects);
1112
1155
modifiesPayload (effects);
1113
1156
}
1114
1157
@@ -1224,6 +1267,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1224
1267
StringAttr::get (getContext (), TransformDialect::kArgConsumedAttrName );
1225
1268
for (auto &&[matcher, action] :
1226
1269
llvm::zip_equal (getMatchers (), getActions ())) {
1270
+ // Presence and typing.
1227
1271
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1228
1272
symbolTable.lookupNearestSymbolFrom (getOperation (),
1229
1273
cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1294,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1250
1294
return failure ();
1251
1295
}
1252
1296
1253
- ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes ();
1254
- ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes ();
1297
+ // Input -> matcher forwarding.
1298
+ TypeRange operandTypes = getOperandTypes ();
1299
+ TypeRange matcherArguments = matcherSymbol.getArgumentTypes ();
1300
+ if (operandTypes.size () != matcherArguments.size ()) {
1301
+ InFlightDiagnostic diag =
1302
+ emitError () << " the number of operands (" << operandTypes.size ()
1303
+ << " ) doesn't match the number of matcher arguments ("
1304
+ << matcherArguments.size () << " ) for " << matcher;
1305
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1306
+ return diag;
1307
+ }
1308
+ for (auto &&[i, operand, argument] :
1309
+ llvm::enumerate (operandTypes, matcherArguments)) {
1310
+ if (matcherSymbol.getArgAttr (i, consumedAttr)) {
1311
+ InFlightDiagnostic diag =
1312
+ emitOpError ()
1313
+ << " does not expect matcher symbol to consume its operand #" << i;
1314
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1315
+ return diag;
1316
+ }
1317
+
1318
+ if (implementSameTransformInterface (operand, argument))
1319
+ continue ;
1320
+
1321
+ InFlightDiagnostic diag =
1322
+ emitError ()
1323
+ << " mismatching type interfaces for operand and matcher argument #"
1324
+ << i << " of matcher " << matcher;
1325
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1326
+ return diag;
1327
+ }
1328
+
1329
+ // Matcher -> action forwarding.
1330
+ TypeRange matcherResults = matcherSymbol.getResultTypes ();
1331
+ TypeRange actionArguments = actionSymbol.getArgumentTypes ();
1255
1332
if (matcherResults.size () != actionArguments.size ()) {
1256
1333
return emitError () << " mismatching number of matcher results and "
1257
1334
" action arguments between "
@@ -1265,31 +1342,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1265
1342
1266
1343
return emitError () << " mismatching type interfaces for matcher result "
1267
1344
" and action argument #"
1268
- << i;
1345
+ << i << " of matcher " << matcher << " and action "
1346
+ << action;
1269
1347
}
1270
1348
1271
- if (!actionSymbol.getResultTypes ().empty ()) {
1349
+ // Action -> result forwarding.
1350
+ TypeRange actionResults = actionSymbol.getResultTypes ();
1351
+ auto resultTypes = TypeRange (getResultTypes ()).drop_front ();
1352
+ if (actionResults.size () != resultTypes.size ()) {
1272
1353
InFlightDiagnostic diag =
1273
- emitError () << " action symbol is not expected to have results" ;
1354
+ emitError () << " the number of action results ("
1355
+ << actionResults.size () << " ) for " << action
1356
+ << " doesn't match the number of extra op results ("
1357
+ << resultTypes.size () << " )" ;
1274
1358
diag.attachNote (actionSymbol->getLoc ()) << " symbol declaration" ;
1275
1359
return diag;
1276
1360
}
1361
+ for (auto &&[i, resultType, actionType] :
1362
+ llvm::enumerate (resultTypes, actionResults)) {
1363
+ if (implementSameTransformInterface (resultType, actionType))
1364
+ continue ;
1277
1365
1278
- if (matcherSymbol.getArgumentTypes ().size () != 1 ||
1279
- !implementSameTransformInterface (matcherSymbol.getArgumentTypes ()[0 ],
1280
- getRoot ().getType ())) {
1281
- InFlightDiagnostic diag =
1282
- emitOpError () << " expects matcher symbol to have one argument with "
1283
- " the same transform interface as the first operand" ;
1284
- diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1285
- return diag;
1286
- }
1287
-
1288
- if (matcherSymbol.getArgAttr (0 , consumedAttr)) {
1289
1366
InFlightDiagnostic diag =
1290
- emitOpError ()
1291
- << " does not expect matcher symbol to consume its operand " ;
1292
- diag.attachNote (matcherSymbol ->getLoc ()) << " symbol declaration" ;
1367
+ emitError () << " mismatching type interfaces for action result # " << i
1368
+ << " of action " << action << " and op result " ;
1369
+ diag.attachNote (actionSymbol ->getLoc ()) << " symbol declaration" ;
1293
1370
return diag;
1294
1371
}
1295
1372
}
0 commit comments