30
30
31
31
using namespace mlir ;
32
32
33
+ // ===----------------------------------------------------------------------===//
34
+ // Helper functions
35
+ // ===----------------------------------------------------------------------===//
36
+
37
+ // / Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
38
+ // / properly dominates `b` and `b` is not inside `a`.
39
+ static bool happensBefore (Operation *a, Operation *b) {
40
+ do {
41
+ if (a->isProperAncestor (b))
42
+ return false ;
43
+ if (Operation *bAncestor = a->getBlock ()->findAncestorOpInBlock (*b)) {
44
+ return a->isBeforeInBlock (bAncestor);
45
+ }
46
+ } while ((a = a->getParentOp ()));
47
+ return false ;
48
+ }
49
+
33
50
// ===----------------------------------------------------------------------===//
34
51
// TransformState
35
52
// ===----------------------------------------------------------------------===//
@@ -44,14 +61,10 @@ transform::TransformState::TransformState(
44
61
topLevelMappedValues.reserve (extraMappings.size ());
45
62
for (ArrayRef<MappedValue> mapping : extraMappings)
46
63
topLevelMappedValues.push_back (mapping);
47
-
48
- auto result =
49
- mappings.insert (std::make_pair (region, std::make_unique<Mappings>()));
50
- assert (result.second && " the region scope is already present" );
51
- (void )result;
52
- #if LLVM_ENABLE_ABI_BREAKING_CHECKS
53
- regionStack.push_back (region);
54
- #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
64
+ if (region) {
65
+ RegionScope *scope = new RegionScope (*this , *region);
66
+ topLevelRegionScope.reset (scope);
67
+ }
55
68
}
56
69
57
70
Operation *transform::TransformState::getTopLevel () const { return topLevel; }
@@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
811
824
LLVM_DEBUG (DBGS () << " Failing Top-level payload:\n " ; getTopLevel ()->print (
812
825
llvm::dbgs (), mlir::OpPrintingFlags ().printGenericOpForm ()););
813
826
});
827
+
828
+ // Set current transform op.
829
+ regionStack.back ()->currentTransform = transform;
830
+
831
+ // Expensive checks to detect invalid transform IR.
814
832
if (options.getExpensiveChecksEnabled ()) {
815
833
FULL_LDBG (" ExpensiveChecksEnabled\n " );
816
834
if (failed (checkAndRecordHandleInvalidation (transform)))
@@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
899
917
}
900
918
901
919
// Prepare rewriter and listener.
902
- transform::ErrorCheckingTrackingListener trackingListener (*this , transform);
920
+ TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
921
+ // Skip handle if it is dead.
922
+ auto scopeIt =
923
+ llvm::find_if (llvm::reverse (regionStack), [&](RegionScope *scope) {
924
+ return handle.getParentRegion () == scope->region ;
925
+ });
926
+ assert (scopeIt != regionStack.rend () &&
927
+ " could not find region scope for handle" );
928
+ RegionScope *scope = *scopeIt;
929
+ for (Operation *user : handle.getUsers ()) {
930
+ if (user != scope->currentTransform &&
931
+ !happensBefore (user, scope->currentTransform ))
932
+ return false ;
933
+ }
934
+ return true ;
935
+ };
936
+ transform::ErrorCheckingTrackingListener trackingListener (*this , transform,
937
+ skipHandleFn);
903
938
transform::TransformRewriter rewriter (transform->getContext (),
904
939
&trackingListener);
905
940
@@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
1040
1075
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1041
1076
1042
1077
state.mappings .erase (region);
1043
-
1044
- #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1045
1078
state.regionStack .pop_back ();
1046
- #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1047
1079
}
1048
1080
1049
1081
// ===----------------------------------------------------------------------===//
@@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
1150
1182
// ===----------------------------------------------------------------------===//
1151
1183
1152
1184
transform::TrackingListener::TrackingListener (TransformState &state,
1153
- TransformOpInterface op)
1154
- : TransformState::Extension(state), transformOp(op) {
1185
+ TransformOpInterface op,
1186
+ SkipHandleFn skipHandleFn)
1187
+ : TransformState::Extension(state), transformOp(op),
1188
+ skipHandleFn(skipHandleFn) {
1155
1189
if (op) {
1156
1190
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands ()) {
1157
1191
consumedHandles.insert (opOperand->get ());
@@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
1251
1285
});
1252
1286
}
1253
1287
1254
- // / Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
1255
- // / properly dominates `b` and `b` is not inside `a`.
1256
- static bool happensBefore (Operation *a, Operation *b) {
1257
- do {
1258
- if (a->isProperAncestor (b))
1259
- return false ;
1260
- if (Operation *bAncestor = a->getBlock ()->findAncestorOpInBlock (*b)) {
1261
- return a->isBeforeInBlock (bAncestor);
1262
- }
1263
- } while ((a = a->getParentOp ()));
1264
- return false ;
1265
- }
1266
-
1267
1288
void transform::TrackingListener::notifyOperationReplaced (
1268
1289
Operation *op, ValueRange newValues) {
1269
1290
assert (op->getNumResults () == newValues.size () &&
@@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
1295
1316
[&](Value h) { return consumedHandles.contains (h); });
1296
1317
};
1297
1318
1298
- // Helper function to check if the handle is alive.
1299
- auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
1300
- for (Value v : opHandles) {
1301
- for (OpOperand &use : v.getUses ())
1302
- if (use.getOwner () != transformOp &&
1303
- !happensBefore (use.getOwner (), transformOp))
1304
- return &use;
1305
- }
1306
- return std::nullopt;
1307
- }();
1308
-
1309
- if (!firstAliveUser.has_value () || handleWasConsumed ()) {
1319
+ // Check if there are any handles that must be updated.
1320
+ Value aliveHandle;
1321
+ if (skipHandleFn) {
1322
+ auto it =
1323
+ llvm::find_if (opHandles, [&](Value v) { return !skipHandleFn (v); });
1324
+ if (it != opHandles.end ())
1325
+ aliveHandle = *it;
1326
+ } else if (!opHandles.empty ()) {
1327
+ aliveHandle = opHandles.front ();
1328
+ }
1329
+ if (!aliveHandle || handleWasConsumed ()) {
1310
1330
// The op is tracked but the corresponding handles are dead or were
1311
1331
// consumed. Drop the op form the mapping.
1312
1332
(void )replacePayloadOp (op, nullptr );
@@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
1319
1339
// If the op is tracked but no replacement op was found, send a
1320
1340
// notification.
1321
1341
if (!diag.succeeded ()) {
1322
- diag.attachNote ((*firstAliveUser)->getOwner ()->getLoc ())
1323
- << " replacement is required because alive handle(s) exist "
1324
- << " (first use in this op as operand number "
1325
- << (*firstAliveUser)->getOperandNumber () << " )" ;
1342
+ diag.attachNote (aliveHandle.getLoc ())
1343
+ << " replacement is required because this handle must be updated" ;
1326
1344
notifyPayloadReplacementNotFound (op, newValues, std::move (diag));
1327
1345
(void )replacePayloadOp (op, nullptr );
1328
1346
return ;
0 commit comments