Skip to content

Commit 8225e02

Browse files
authored
[AutoDiff] Minor activity analysis changes. (#28301)
Hoist activity marking visited value set out of loop over original bbs. This is safe because bbs directly start with dominator bbs's active values. Visit bb arguments for activity marking. This was accidentally deleted in #28225. Re-adding the logic doesn't seem to affect any tests.
1 parent 04dca63 commit 8225e02

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6231,6 +6231,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62316231
// Adjoint values of dominated active values are passed as pullback block
62326232
// arguments.
62336233
DominanceOrder domOrder(original.getEntryBlock(), domInfo);
6234+
// Keep track of visited values.
6235+
SmallPtrSet<SILValue, 8> visited;
62346236
while (auto *bb = domOrder.getNext()) {
62356237
auto &bbActiveValues = activeValues[bb];
62366238
// If the current block has an immediate dominator, append the immediate
@@ -6240,13 +6242,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62406242
bbActiveValues.append(domBBActiveValues.begin(),
62416243
domBBActiveValues.end());
62426244
}
6243-
SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(),
6244-
bbActiveValues.end());
6245-
// Register a value as active if it has not yet been visited.
62466245
bool diagnosedActiveEnumValue = false;
6247-
auto addActiveValue = [&](SILValue v) {
6246+
// Mark the activity of a value if it has not yet been visited.
6247+
auto markValueActivity = [&](SILValue v) {
62486248
if (visited.count(v))
62496249
return;
6250+
visited.insert(v);
62506251
// Diagnose active enum values. Differentiation of enum values requires
62516252
// special adjoint value handling and is not yet supported. Diagnose
62526253
// only the first active enum value to prevent too many diagnostics.
@@ -6262,17 +6263,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62626263
// become projections into their adjoint base buffer.
62636264
if (Projection::isAddressProjection(v))
62646265
return;
6265-
visited.insert(v);
62666266
bbActiveValues.push_back(v);
62676267
};
6268-
// Register bb arguments and all instruction operands/results.
6268+
// Visit bb arguments and all instruction operands/results.
6269+
for (auto *arg : bb->getArguments())
6270+
if (getActivityInfo().isActive(arg, getIndices()))
6271+
markValueActivity(arg);
62696272
for (auto &inst : *bb) {
62706273
for (auto op : inst.getOperandValues())
62716274
if (getActivityInfo().isActive(op, getIndices()))
6272-
addActiveValue(op);
6275+
markValueActivity(op);
62736276
for (auto result : inst.getResults())
62746277
if (getActivityInfo().isActive(result, getIndices()))
6275-
addActiveValue(result);
6278+
markValueActivity(result);
62766279
}
62776280
domOrder.pushChildren(bb);
62786281
if (errorOccurred)

0 commit comments

Comments
 (0)