@@ -4130,6 +4130,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4130
4130
: nullptr ;
4131
4131
Lowering::GenericContextScope genericContextScope (
4132
4132
getContext ().getTypeConverter (), adjGenSig);
4133
+ auto origExitIt = original.findReturnBB ();
4134
+ assert (!origExitIt.isEnd () &&
4135
+ " Functions without returns must have been diagnosed" );
4136
+ auto *origExit = &*origExitIt;
4133
4137
4134
4138
// Get dominated active values in original blocks.
4135
4139
// Adjoint values of dominated active values are passed as adjoint block
@@ -4168,18 +4172,29 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4168
4172
}
4169
4173
domOrder.pushChildren (bb);
4170
4174
}
4171
-
4175
+
4172
4176
// Create adjoint blocks and arguments, visiting original blocks in
4173
4177
// post-order.
4174
4178
for (auto *origBB : postOrderInfo->getPostOrder ()) {
4175
4179
auto *adjointBB = adjoint.createBasicBlock ();
4176
4180
adjointBBMap.insert ({origBB, adjointBB});
4177
- // If adjoint block is the adjoint entry, continue.
4178
- if (adjointBB->isEntry ())
4179
- continue ;
4180
- // Otherwise, add a pullback struct argument to the adjoint block.
4181
4181
auto pbStructLoweredType =
4182
4182
remapType (getPullbackInfo ().getPullbackStructLoweredType (origBB));
4183
+ // If the BB is the original exit, then the adjoint block that we just
4184
+ // createed must be the adjoint function's entry. We always know what an
4185
+ // entry block's arguments should be, so we generate them and skip to the
4186
+ // the next block.
4187
+ if (origBB == origExit) {
4188
+ assert (adjointBB->isEntry ());
4189
+ createEntryArguments (&getAdjoint ());
4190
+ auto *lastArg = adjointBB->getArguments ().back ();
4191
+ assert (lastArg->getType () == pbStructLoweredType);
4192
+ adjointPullbackStructArguments[origBB] = lastArg;
4193
+ continue ;
4194
+ }
4195
+
4196
+ // Otherwise, we create a phi argument for the corresponding primal value
4197
+ // struct, and then turn active values into phi arguments.
4183
4198
auto *pbStructArg = adjointBB->createPhiArgument (
4184
4199
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
4185
4200
adjointPullbackStructArguments[origBB] = pbStructArg;
@@ -4221,15 +4236,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4221
4236
}
4222
4237
}
4223
4238
4224
- auto *origEntry = original.getEntryBlock ();
4225
- auto *origExit = &*original.findReturnBB ();
4226
4239
auto *adjointEntry = adjoint.getEntryBlock ();
4227
- createEntryArguments (&adjoint);
4228
4240
// The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
4229
4241
auto adjParamArgs = adjoint.getArgumentsWithoutIndirectResults ();
4230
4242
assert (adjParamArgs.size () == 2 );
4231
4243
seed = adjParamArgs[0 ];
4232
- adjointPullbackStructArguments[origExit] = adjParamArgs[1 ];
4233
4244
4234
4245
// Assign adjoint for original result.
4235
4246
SmallVector<SILValue, 8 > origFormalResults;
@@ -4433,6 +4444,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4433
4444
// Place the builder at the adjoint exit, i.e. the adjoint block
4434
4445
// corresponding to the original entry. Return the adjoints wrt parameters
4435
4446
// in the adjoint exit.
4447
+ auto *origEntry = getOriginal ().getEntryBlock ();
4436
4448
builder.setInsertionPoint (getAdjointBlock (origEntry));
4437
4449
4438
4450
// This vector will contain all the materialized return elements.
0 commit comments