Skip to content

Commit e8ee0fe

Browse files
committed
[AutoDiff] Unify adjoint entry block generation with the general logic.
In control flow AD support, the logic that handles non-entry adjoint block generation should not be different than the general logic that generates intermediate blocks and processes active values. The only special thing about the adjoint entry block is that it need no phi args because the original block (the original exit) has no successors. This PR unifies the block generation logic, moving us a bit closer to the memory leak fix.
1 parent 045e192 commit e8ee0fe

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4130,6 +4130,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41304130
: nullptr;
41314131
Lowering::GenericContextScope genericContextScope(
41324132
getContext().getTypeConverter(), adjGenSig);
4133+
auto origExitIt = original.findReturnBB();
4134+
assert(!origExitIt.isEnd() &&
4135+
"Functions without returns must have been diagnosed");
4136+
auto *origExit = &*origExitIt;
41334137

41344138
// Get dominated active values in original blocks.
41354139
// Adjoint values of dominated active values are passed as adjoint block
@@ -4168,18 +4172,29 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41684172
}
41694173
domOrder.pushChildren(bb);
41704174
}
4171-
4175+
41724176
// Create adjoint blocks and arguments, visiting original blocks in
41734177
// post-order.
41744178
for (auto *origBB : postOrderInfo->getPostOrder()) {
41754179
auto *adjointBB = adjoint.createBasicBlock();
41764180
adjointBBMap.insert({origBB, adjointBB});
4177-
// If adjoint block is the adjoint entry, continue.
4178-
if (adjointBB->isEntry())
4179-
continue;
4180-
// Otherwise, add a pullback struct argument to the adjoint block.
41814181
auto pbStructLoweredType =
41824182
remapType(getPullbackInfo().getPullbackStructLoweredType(origBB));
4183+
// If the BB is the original exit, then the adjoint block that we just
4184+
// createed must be the adjoint function's entry. We always know what an
4185+
// entry block's arguments should be, so we generate them and skip to the
4186+
// the next block.
4187+
if (origBB == origExit) {
4188+
assert(adjointBB->isEntry());
4189+
createEntryArguments(&getAdjoint());
4190+
auto *lastArg = adjointBB->getArguments().back();
4191+
assert(lastArg->getType() == pbStructLoweredType);
4192+
adjointPullbackStructArguments[origBB] = lastArg;
4193+
continue;
4194+
}
4195+
4196+
// Otherwise, we create a phi argument for the corresponding primal value
4197+
// struct, and then turn active values into phi arguments.
41834198
auto *pbStructArg = adjointBB->createPhiArgument(
41844199
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
41854200
adjointPullbackStructArguments[origBB] = pbStructArg;
@@ -4221,15 +4236,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
42214236
}
42224237
}
42234238

4224-
auto *origEntry = original.getEntryBlock();
4225-
auto *origExit = &*original.findReturnBB();
42264239
auto *adjointEntry = adjoint.getEntryBlock();
4227-
createEntryArguments(&adjoint);
42284240
// The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
42294241
auto adjParamArgs = adjoint.getArgumentsWithoutIndirectResults();
42304242
assert(adjParamArgs.size() == 2);
42314243
seed = adjParamArgs[0];
4232-
adjointPullbackStructArguments[origExit] = adjParamArgs[1];
42334244

42344245
// Assign adjoint for original result.
42354246
SmallVector<SILValue, 8> origFormalResults;
@@ -4433,6 +4444,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44334444
// Place the builder at the adjoint exit, i.e. the adjoint block
44344445
// corresponding to the original entry. Return the adjoints wrt parameters
44354446
// in the adjoint exit.
4447+
auto *origEntry = getOriginal().getEntryBlock();
44364448
builder.setInsertionPoint(getAdjointBlock(origEntry));
44374449

44384450
// This vector will contain all the materialized return elements.

0 commit comments

Comments
 (0)