@@ -6311,6 +6311,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6311
6311
// Adjoint values of dominated active values are passed as pullback block
6312
6312
// arguments.
6313
6313
DominanceOrder domOrder (original.getEntryBlock (), domInfo);
6314
+ // Keep track of visited values.
6315
+ SmallPtrSet<SILValue, 8 > visited;
6314
6316
while (auto *bb = domOrder.getNext ()) {
6315
6317
auto &bbActiveValues = activeValues[bb];
6316
6318
// If the current block has an immediate dominator, append the immediate
@@ -6320,13 +6322,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6320
6322
bbActiveValues.append (domBBActiveValues.begin (),
6321
6323
domBBActiveValues.end ());
6322
6324
}
6323
- SmallPtrSet<SILValue, 8 > visited (bbActiveValues.begin (),
6324
- bbActiveValues.end ());
6325
- // Register a value as active if it has not yet been visited.
6326
6325
bool diagnosedActiveEnumValue = false ;
6327
- auto addActiveValue = [&](SILValue v) {
6326
+ // Mark the activity of a value if it has not yet been visited.
6327
+ auto markValueActivity = [&](SILValue v) {
6328
6328
if (visited.count (v))
6329
6329
return ;
6330
+ visited.insert (v);
6330
6331
// Diagnose active enum values. Differentiation of enum values requires
6331
6332
// special adjoint value handling and is not yet supported. Diagnose
6332
6333
// only the first active enum value to prevent too many diagnostics.
@@ -6342,17 +6343,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6342
6343
// become projections into their adjoint base buffer.
6343
6344
if (Projection::isAddressProjection (v))
6344
6345
return ;
6345
- visited.insert (v);
6346
6346
bbActiveValues.push_back (v);
6347
6347
};
6348
- // Register bb arguments and all instruction operands/results.
6348
+ // Visit bb arguments and all instruction operands/results.
6349
+ for (auto *arg : bb->getArguments ())
6350
+ if (getActivityInfo ().isActive (arg, getIndices ()))
6351
+ markValueActivity (arg);
6349
6352
for (auto &inst : *bb) {
6350
6353
for (auto op : inst.getOperandValues ())
6351
6354
if (getActivityInfo ().isActive (op, getIndices ()))
6352
- addActiveValue (op);
6355
+ markValueActivity (op);
6353
6356
for (auto result : inst.getResults ())
6354
6357
if (getActivityInfo ().isActive (result, getIndices ()))
6355
- addActiveValue (result);
6358
+ markValueActivity (result);
6356
6359
}
6357
6360
domOrder.pushChildren (bb);
6358
6361
if (errorOccurred)
0 commit comments