@@ -4131,6 +4131,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4131
4131
: nullptr ;
4132
4132
Lowering::GenericContextScope genericContextScope (
4133
4133
getContext ().getTypeConverter (), adjGenSig);
4134
+ auto origExitIt = original.findReturnBB ();
4135
+ assert (origExitIt != original.end () &&
4136
+ " Functions without returns must have been diagnosed" );
4137
+ auto *origExit = &*origExitIt;
4134
4138
4135
4139
// Get dominated active values in original blocks.
4136
4140
// Adjoint values of dominated active values are passed as adjoint block
@@ -4175,12 +4179,22 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4175
4179
for (auto *origBB : postOrderInfo->getPostOrder ()) {
4176
4180
auto *adjointBB = adjoint.createBasicBlock ();
4177
4181
adjointBBMap.insert ({origBB, adjointBB});
4178
- // If adjoint block is the adjoint entry, continue.
4179
- if (adjointBB->isEntry ())
4180
- continue ;
4181
- // Otherwise, add a pullback struct argument to the adjoint block.
4182
4182
auto pbStructLoweredType =
4183
4183
remapType (getPullbackInfo ().getPullbackStructLoweredType (origBB));
4184
+ // If the BB is the original exit, then the adjoint block that we just
4185
+ // created must be the adjoint function's entry. For the adjoint entry,
4186
+ // create entry arguments and continue to the next block.
4187
+ if (origBB == origExit) {
4188
+ assert (adjointBB->isEntry ());
4189
+ createEntryArguments (&getAdjoint ());
4190
+ auto *lastArg = adjointBB->getArguments ().back ();
4191
+ assert (lastArg->getType () == pbStructLoweredType);
4192
+ adjointPullbackStructArguments[origBB] = lastArg;
4193
+ continue ;
4194
+ }
4195
+
4196
+ // Otherwise, we create a phi argument for the corresponding pullback
4197
+ // struct, and handle dominated active values/buffers.
4184
4198
auto *pbStructArg = adjointBB->createPhiArgument (
4185
4199
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
4186
4200
adjointPullbackStructArguments[origBB] = pbStructArg;
@@ -4222,15 +4236,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4222
4236
}
4223
4237
}
4224
4238
4225
- auto *origEntry = original.getEntryBlock ();
4226
- auto *origExit = &*original.findReturnBB ();
4227
4239
auto *adjointEntry = adjoint.getEntryBlock ();
4228
- createEntryArguments (&adjoint);
4229
4240
// The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
4230
4241
auto adjParamArgs = adjoint.getArgumentsWithoutIndirectResults ();
4231
4242
assert (adjParamArgs.size () == 2 );
4232
4243
seed = adjParamArgs[0 ];
4233
- adjointPullbackStructArguments[origExit] = adjParamArgs[1 ];
4234
4244
4235
4245
// Assign adjoint for original result.
4236
4246
SmallVector<SILValue, 8 > origFormalResults;
@@ -4434,6 +4444,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4434
4444
// Place the builder at the adjoint exit, i.e. the adjoint block
4435
4445
// corresponding to the original entry. Return the adjoints wrt parameters
4436
4446
// in the adjoint exit.
4447
+ auto *origEntry = getOriginal ().getEntryBlock ();
4437
4448
builder.setInsertionPoint (getAdjointBlock (origEntry));
4438
4449
4439
4450
// This vector will contain all the materialized return elements.
0 commit comments