@@ -834,19 +834,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
834
834
// CollectMatchingOp
835
835
// ===----------------------------------------------------------------------===//
836
836
837
- // / Applies matcher operations from the given `block` assigning `op` as the
838
- // / payload of the block's first argument. Updates `state` accordingly. If any
839
- // / of the matcher produces a silenceable failure, discards it (printing the
840
- // / content to the debug output stream) and returns failure. If any of the
841
- // / matchers produces a definite failure, reports it and returns failure. If all
842
- // / matchers in the block succeed, populates `mappings` with the payload
843
- // / entities associated with the block terminator operands.
837
+ // / Applies matcher operations from the given `block` using
838
+ // / `blockArgumentMapping` to initialize block arguments. Updates `state`
839
+ // / accordingly. If any of the matcher produces a silenceable failure, discards
840
+ // / it (printing the content to the debug output stream) and returns failure. If
841
+ // / any of the matchers produces a definite failure, reports it and returns
842
+ // / failure. If all matchers in the block succeed, populates `mappings` with the
843
+ // / payload entities associated with the block terminator operands. Note that
844
+ // / `mappings` will be cleared before that.
844
845
static DiagnosedSilenceableFailure
845
- matchBlock (Block &block, Operation *op, transform::TransformState &state,
846
+ matchBlock (Block &block,
847
+ ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
848
+ transform::TransformState &state,
846
849
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847
850
assert (block.getParent () && " cannot match using a detached block" );
848
851
auto matchScope = state.make_region_scope (*block.getParent ());
849
- if (failed (state.mapBlockArgument (block.getArgument (0 ), {op})))
852
+ if (failed (
853
+ state.mapBlockArguments (block.getArguments (), blockArgumentMapping)))
850
854
return DiagnosedSilenceableFailure::definiteFailure ();
851
855
852
856
for (Operation &match : block.without_terminator ()) {
@@ -866,6 +870,7 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
866
870
// Remember the values mapped to the terminator operands so we can
867
871
// forward them to the action.
868
872
ValueRange yieldedValues = block.getTerminator ()->getOperands ();
873
+ mappings.clear ();
869
874
transform::detail::prepareValueMappings (mappings, yieldedValues, state);
870
875
return DiagnosedSilenceableFailure::success ();
871
876
}
@@ -915,8 +920,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
915
920
916
921
// Try matching.
917
922
SmallVector<SmallVector<MappedValue>> mappings;
918
- DiagnosedSilenceableFailure diag =
919
- matchBlock (matcher.getFunctionBody ().front (), op, state, mappings);
923
+ SmallVector<transform::MappedValue> inputMapping ({op});
924
+ DiagnosedSilenceableFailure diag = matchBlock (
925
+ matcher.getFunctionBody ().front (),
926
+ ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
927
+ mappings);
920
928
if (diag.isDefiniteFailure ())
921
929
return WalkResult::interrupt ();
922
930
if (diag.isSilenceableFailure ()) {
@@ -1001,6 +1009,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1001
1009
// ForeachMatchOp
1002
1010
// ===----------------------------------------------------------------------===//
1003
1011
1012
+ // This is fine because nothing is actually consumed by this op.
1013
+ bool transform::ForeachMatchOp::allowsRepeatedHandleOperands () { return true ; }
1014
+
1004
1015
DiagnosedSilenceableFailure
1005
1016
transform::ForeachMatchOp::apply (transform::TransformRewriter &rewriter,
1006
1017
transform::TransformResults &results,
@@ -1030,6 +1041,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1030
1041
1031
1042
DiagnosedSilenceableFailure overallDiag =
1032
1043
DiagnosedSilenceableFailure::success ();
1044
+
1045
+ SmallVector<SmallVector<MappedValue>> matchInputMapping;
1046
+ SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1047
+ SmallVector<SmallVector<MappedValue>> actionResultMapping;
1048
+ matchInputMapping.emplace_back ();
1049
+ transform::detail::prepareValueMappings (matchInputMapping,
1050
+ getForwardedInputs (), state);
1051
+ SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front ();
1052
+ actionResultMapping.resize (getForwardedOutputs ().size ());
1053
+
1033
1054
for (Operation *root : state.getPayloadOps (getRoot ())) {
1034
1055
WalkResult walkResult = root->walk ([&](Operation *op) {
1035
1056
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1065,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1044
1065
llvm::dbgs () << " @" << op << " \n " ;
1045
1066
});
1046
1067
1068
+ firstMatchArgument.clear ();
1069
+ firstMatchArgument.push_back (op);
1070
+
1047
1071
// Try all the match/action pairs until the first successful match.
1048
1072
for (auto [matcher, action] : matchActionPairs) {
1049
- SmallVector<SmallVector<MappedValue>> mappings;
1050
1073
DiagnosedSilenceableFailure diag =
1051
- matchBlock (matcher.getFunctionBody ().front (), op, state, mappings);
1074
+ matchBlock (matcher.getFunctionBody ().front (), matchInputMapping,
1075
+ state, matchOutputMapping);
1052
1076
if (diag.isDefiniteFailure ())
1053
1077
return WalkResult::interrupt ();
1054
1078
if (diag.isSilenceableFailure ()) {
@@ -1058,10 +1082,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1058
1082
}
1059
1083
1060
1084
auto scope = state.make_region_scope (action.getFunctionBody ());
1061
- for ( auto &&[arg, map] : llvm::zip_equal (
1062
- action.getFunctionBody ().front ().getArguments (), mappings)) {
1063
- if ( failed (state. mapBlockArgument (arg, map )))
1064
- return WalkResult::interrupt ();
1085
+ if ( failed (state. mapBlockArguments (
1086
+ action.getFunctionBody ().front ().getArguments (),
1087
+ matchOutputMapping ))) {
1088
+ return WalkResult::interrupt ();
1065
1089
}
1066
1090
1067
1091
for (Operation &transform :
@@ -1082,6 +1106,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1082
1106
continue ;
1083
1107
}
1084
1108
}
1109
+ if (failed (detail::appendValueMappings (
1110
+ MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1111
+ action.getFunctionBody ().front ().getTerminator ()->getOperands (),
1112
+ state, getFlattenResults ()))) {
1113
+ emitDefiniteFailure ()
1114
+ << " action @" << action.getName ()
1115
+ << " has results associated with multiple payload entities, "
1116
+ " but flattening was not requested" ;
1117
+ return WalkResult::interrupt ();
1118
+ }
1085
1119
break ;
1086
1120
}
1087
1121
return WalkResult::advance ();
@@ -1096,6 +1130,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1096
1130
// by actions, are invalidated.
1097
1131
results.set (llvm::cast<OpResult>(getUpdated ()),
1098
1132
state.getPayloadOps (getRoot ()));
1133
+ for (auto &&[result, mapping] :
1134
+ llvm::zip_equal (getForwardedOutputs (), actionResultMapping)) {
1135
+ results.setMappedValues (result, mapping);
1136
+ }
1099
1137
return overallDiag;
1100
1138
}
1101
1139
@@ -1108,7 +1146,8 @@ void transform::ForeachMatchOp::getEffects(
1108
1146
}
1109
1147
1110
1148
consumesHandle (getRoot (), effects);
1111
- producesHandle (getUpdated (), effects);
1149
+ onlyReadsHandle (getForwardedInputs (), effects);
1150
+ producesHandle (getResults (), effects);
1112
1151
modifiesPayload (effects);
1113
1152
}
1114
1153
@@ -1224,6 +1263,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1224
1263
StringAttr::get (getContext (), TransformDialect::kArgConsumedAttrName );
1225
1264
for (auto &&[matcher, action] :
1226
1265
llvm::zip_equal (getMatchers (), getActions ())) {
1266
+ // Presence and typing.
1227
1267
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1228
1268
symbolTable.lookupNearestSymbolFrom (getOperation (),
1229
1269
cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1290,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1250
1290
return failure ();
1251
1291
}
1252
1292
1253
- ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes ();
1254
- ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes ();
1293
+ // Input -> matcher forwarding.
1294
+ TypeRange operandTypes = getOperandTypes ();
1295
+ TypeRange matcherArguments = matcherSymbol.getArgumentTypes ();
1296
+ if (operandTypes.size () != matcherArguments.size ()) {
1297
+ InFlightDiagnostic diag =
1298
+ emitError () << " the number of operands (" << operandTypes.size ()
1299
+ << " ) doesn't match the number of matcher arguments ("
1300
+ << matcherArguments.size () << " ) for " << matcher;
1301
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1302
+ return diag;
1303
+ }
1304
+ for (auto &&[i, operand, argument] :
1305
+ llvm::enumerate (operandTypes, matcherArguments)) {
1306
+ if (matcherSymbol.getArgAttr (i, consumedAttr)) {
1307
+ InFlightDiagnostic diag =
1308
+ emitOpError ()
1309
+ << " does not expect matcher symbol to consume its operand #" << i;
1310
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1311
+ return diag;
1312
+ }
1313
+
1314
+ if (implementSameTransformInterface (operand, argument))
1315
+ continue ;
1316
+
1317
+ InFlightDiagnostic diag =
1318
+ emitError ()
1319
+ << " mismatching type interfaces for operand and matcher argument #"
1320
+ << i << " of matcher " << matcher;
1321
+ diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1322
+ return diag;
1323
+ }
1324
+
1325
+ // Matcher -> action forwarding.
1326
+ TypeRange matcherResults = matcherSymbol.getResultTypes ();
1327
+ TypeRange actionArguments = actionSymbol.getArgumentTypes ();
1255
1328
if (matcherResults.size () != actionArguments.size ()) {
1256
1329
return emitError () << " mismatching number of matcher results and "
1257
1330
" action arguments between "
@@ -1265,31 +1338,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1265
1338
1266
1339
return emitError () << " mismatching type interfaces for matcher result "
1267
1340
" and action argument #"
1268
- << i;
1341
+ << i << " of matcher " << matcher << " and action "
1342
+ << action;
1269
1343
}
1270
1344
1271
- if (!actionSymbol.getResultTypes ().empty ()) {
1345
+ // Action -> result forwarding.
1346
+ TypeRange actionResults = actionSymbol.getResultTypes ();
1347
+ auto resultTypes = TypeRange (getResultTypes ()).drop_front ();
1348
+ if (actionResults.size () != resultTypes.size ()) {
1272
1349
InFlightDiagnostic diag =
1273
- emitError () << " action symbol is not expected to have results" ;
1350
+ emitError () << " the number of action results ("
1351
+ << actionResults.size () << " ) for " << action
1352
+ << " doesn't match the number of extra op results ("
1353
+ << resultTypes.size () << " )" ;
1274
1354
diag.attachNote (actionSymbol->getLoc ()) << " symbol declaration" ;
1275
1355
return diag;
1276
1356
}
1357
+ for (auto &&[i, resultType, actionType] :
1358
+ llvm::enumerate (resultTypes, actionResults)) {
1359
+ if (implementSameTransformInterface (resultType, actionType))
1360
+ continue ;
1277
1361
1278
- if (matcherSymbol.getArgumentTypes ().size () != 1 ||
1279
- !implementSameTransformInterface (matcherSymbol.getArgumentTypes ()[0 ],
1280
- getRoot ().getType ())) {
1281
- InFlightDiagnostic diag =
1282
- emitOpError () << " expects matcher symbol to have one argument with "
1283
- " the same transform interface as the first operand" ;
1284
- diag.attachNote (matcherSymbol->getLoc ()) << " symbol declaration" ;
1285
- return diag;
1286
- }
1287
-
1288
- if (matcherSymbol.getArgAttr (0 , consumedAttr)) {
1289
1362
InFlightDiagnostic diag =
1290
- emitOpError ()
1291
- << " does not expect matcher symbol to consume its operand " ;
1292
- diag.attachNote (matcherSymbol ->getLoc ()) << " symbol declaration" ;
1363
+ emitError () << " mismatching type interfaces for action result # " << i
1364
+ << " of action " << action << " and op result " ;
1365
+ diag.attachNote (actionSymbol ->getLoc ()) << " symbol declaration" ;
1293
1366
return diag;
1294
1367
}
1295
1368
}
0 commit comments