@@ -6231,6 +6231,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6231
6231
// Adjoint values of dominated active values are passed as pullback block
6232
6232
// arguments.
6233
6233
DominanceOrder domOrder (original.getEntryBlock (), domInfo);
6234
+ // Keep track of visited values.
6235
+ SmallPtrSet<SILValue, 8 > visited;
6234
6236
while (auto *bb = domOrder.getNext ()) {
6235
6237
auto &bbActiveValues = activeValues[bb];
6236
6238
// If the current block has an immediate dominator, append the immediate
@@ -6240,13 +6242,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6240
6242
bbActiveValues.append (domBBActiveValues.begin (),
6241
6243
domBBActiveValues.end ());
6242
6244
}
6243
- SmallPtrSet<SILValue, 8 > visited (bbActiveValues.begin (),
6244
- bbActiveValues.end ());
6245
- // Register a value as active if it has not yet been visited.
6246
6245
bool diagnosedActiveEnumValue = false ;
6247
- auto addActiveValue = [&](SILValue v) {
6246
+ // Mark the activity of a value if it has not yet been visited.
6247
+ auto markValueActivity = [&](SILValue v) {
6248
6248
if (visited.count (v))
6249
6249
return ;
6250
+ visited.insert (v);
6250
6251
// Diagnose active enum values. Differentiation of enum values requires
6251
6252
// special adjoint value handling and is not yet supported. Diagnose
6252
6253
// only the first active enum value to prevent too many diagnostics.
@@ -6262,17 +6263,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6262
6263
// become projections into their adjoint base buffer.
6263
6264
if (Projection::isAddressProjection (v))
6264
6265
return ;
6265
- visited.insert (v);
6266
6266
bbActiveValues.push_back (v);
6267
6267
};
6268
- // Register bb arguments and all instruction operands/results.
6268
+ // Visit bb arguments and all instruction operands/results.
6269
+ for (auto *arg : bb->getArguments ())
6270
+ if (getActivityInfo ().isActive (arg, getIndices ()))
6271
+ markValueActivity (arg);
6269
6272
for (auto &inst : *bb) {
6270
6273
for (auto op : inst.getOperandValues ())
6271
6274
if (getActivityInfo ().isActive (op, getIndices ()))
6272
- addActiveValue (op);
6275
+ markValueActivity (op);
6273
6276
for (auto result : inst.getResults ())
6274
6277
if (getActivityInfo ().isActive (result, getIndices ()))
6275
- addActiveValue (result);
6278
+ markValueActivity (result);
6276
6279
}
6277
6280
domOrder.pushChildren (bb);
6278
6281
if (errorOccurred)
0 commit comments