Skip to content

Commit 5ac83e2

Browse files
authored
[AutoDiff] Unify adjoint entry block generation with the general logic. (#25255)
In control flow AD support, the logic that handles entry adjoint block generation should not be different than the general logic that generates intermediate blocks and processes active values. The only special thing about the entry adjoint block is that it need no phi args because its corresponding original block (the original exit) has no successors. This PR unifies the block generation logic, moving us a bit closer to the memory leak fix.
1 parent 1c1305c commit 5ac83e2

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4131,6 +4131,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41314131
: nullptr;
41324132
Lowering::GenericContextScope genericContextScope(
41334133
getContext().getTypeConverter(), adjGenSig);
4134+
auto origExitIt = original.findReturnBB();
4135+
assert(origExitIt != original.end() &&
4136+
"Functions without returns must have been diagnosed");
4137+
auto *origExit = &*origExitIt;
41344138

41354139
// Get dominated active values in original blocks.
41364140
// Adjoint values of dominated active values are passed as adjoint block
@@ -4175,12 +4179,22 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41754179
for (auto *origBB : postOrderInfo->getPostOrder()) {
41764180
auto *adjointBB = adjoint.createBasicBlock();
41774181
adjointBBMap.insert({origBB, adjointBB});
4178-
// If adjoint block is the adjoint entry, continue.
4179-
if (adjointBB->isEntry())
4180-
continue;
4181-
// Otherwise, add a pullback struct argument to the adjoint block.
41824182
auto pbStructLoweredType =
41834183
remapType(getPullbackInfo().getPullbackStructLoweredType(origBB));
4184+
// If the BB is the original exit, then the adjoint block that we just
4185+
// created must be the adjoint function's entry. For the adjoint entry,
4186+
// create entry arguments and continue to the next block.
4187+
if (origBB == origExit) {
4188+
assert(adjointBB->isEntry());
4189+
createEntryArguments(&getAdjoint());
4190+
auto *lastArg = adjointBB->getArguments().back();
4191+
assert(lastArg->getType() == pbStructLoweredType);
4192+
adjointPullbackStructArguments[origBB] = lastArg;
4193+
continue;
4194+
}
4195+
4196+
// Otherwise, we create a phi argument for the corresponding pullback
4197+
// struct, and handle dominated active values/buffers.
41844198
auto *pbStructArg = adjointBB->createPhiArgument(
41854199
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
41864200
adjointPullbackStructArguments[origBB] = pbStructArg;
@@ -4222,15 +4236,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
42224236
}
42234237
}
42244238

4225-
auto *origEntry = original.getEntryBlock();
4226-
auto *origExit = &*original.findReturnBB();
42274239
auto *adjointEntry = adjoint.getEntryBlock();
4228-
createEntryArguments(&adjoint);
42294240
// The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
42304241
auto adjParamArgs = adjoint.getArgumentsWithoutIndirectResults();
42314242
assert(adjParamArgs.size() == 2);
42324243
seed = adjParamArgs[0];
4233-
adjointPullbackStructArguments[origExit] = adjParamArgs[1];
42344244

42354245
// Assign adjoint for original result.
42364246
SmallVector<SILValue, 8> origFormalResults;
@@ -4434,6 +4444,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44344444
// Place the builder at the adjoint exit, i.e. the adjoint block
44354445
// corresponding to the original entry. Return the adjoints wrt parameters
44364446
// in the adjoint exit.
4447+
auto *origEntry = getOriginal().getEntryBlock();
44374448
builder.setInsertionPoint(getAdjointBlock(origEntry));
44384449

44394450
// This vector will contain all the materialized return elements.

0 commit comments

Comments
 (0)