@@ -4378,8 +4378,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4378
4378
// / any error occurs.
4379
4379
bool run () {
4380
4380
auto &original = getOriginal ();
4381
- auto &adjoint = getPullback ();
4382
- auto adjLoc = getPullback ().getLocation ();
4381
+ auto &pullback = getPullback ();
4382
+ auto pbLoc = getPullback ().getLocation ();
4383
4383
LLVM_DEBUG (getADDebugStream () << " Running PullbackEmitter on\n " << original);
4384
4384
4385
4385
auto *adjGenEnv = getPullback ().getGenericEnvironment ();
@@ -4446,7 +4446,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4446
4446
if (errorOccurred)
4447
4447
return true ;
4448
4448
4449
- // Create adjoint blocks and arguments, visiting original blocks in
4449
+ // Create pullback blocks and arguments, visiting original blocks in
4450
4450
// post-order post-dominance order.
4451
4451
SmallVector<SILBasicBlock *, 8 > postOrderPostDomOrder;
4452
4452
// Start from the root node, which may have a marker `nullptr` block if
@@ -4462,17 +4462,17 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4462
4462
postOrderPostDomOrder.push_back (origBB);
4463
4463
}
4464
4464
for (auto *origBB : postOrderPostDomOrder) {
4465
- auto *adjointBB = adjoint .createBasicBlock ();
4466
- adjointBBMap.insert ({origBB, adjointBB });
4465
+ auto *pullbackBB = pullback .createBasicBlock ();
4466
+ adjointBBMap.insert ({origBB, pullbackBB });
4467
4467
auto pbStructLoweredType =
4468
4468
remapType (getPullbackInfo ().getPullbackStructLoweredType (origBB));
4469
- // If the BB is the original exit, then the adjoint block that we just
4470
- // created must be the adjoint function's entry. For the adjoint entry,
4469
+ // If the BB is the original exit, then the pullback block that we just
4470
+ // created must be the pullback function's entry. For the pullback entry,
4471
4471
// create entry arguments and continue to the next block.
4472
4472
if (origBB == origExit) {
4473
- assert (adjointBB ->isEntry ());
4473
+ assert (pullbackBB ->isEntry ());
4474
4474
createEntryArguments (&getPullback ());
4475
- auto *lastArg = adjointBB ->getArguments ().back ();
4475
+ auto *lastArg = pullbackBB ->getArguments ().back ();
4476
4476
assert (lastArg->getType () == pbStructLoweredType);
4477
4477
pullbackStructArguments[origBB] = lastArg;
4478
4478
continue ;
@@ -4492,27 +4492,29 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4492
4492
if (activeValue->getType ().isAddress ()) {
4493
4493
// Allocate and zero initialize a new local buffer using
4494
4494
// `getAdjointBuffer`.
4495
- builder.setInsertionPoint (adjoint .getEntryBlock ());
4495
+ builder.setInsertionPoint (pullback .getEntryBlock ());
4496
4496
getAdjointBuffer (origBB, activeValue);
4497
4497
} else {
4498
- // Create and register adjoint block argument for the active value.
4499
- auto *adjointArg = adjointBB ->createPhiArgument (
4498
+ // Create and register pullback block argument for the active value.
4499
+ auto *adjointArg = pullbackBB ->createPhiArgument (
4500
4500
getRemappedTangentType (activeValue->getType ()),
4501
4501
ValueOwnershipKind::Guaranteed);
4502
4502
activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg;
4503
4503
}
4504
4504
}
4505
4505
// Add a pullback struct argument.
4506
- auto *pbStructArg = adjointBB ->createPhiArgument (
4506
+ auto *pbStructArg = pullbackBB ->createPhiArgument (
4507
4507
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
4508
4508
pullbackStructArguments[origBB] = pbStructArg;
4509
- // - Create adjoint trampoline blocks for each successor block of the
4509
+ // - Create pullback trampoline blocks for each successor block of the
4510
4510
// original block. Adjoint trampoline blocks only have a pullback
4511
4511
// struct argument, and branch from the adjoint successor block to the
4512
4512
// adjoint original block, trampoline adjoint values of active values.
4513
4513
for (auto *succBB : origBB->getSuccessorBlocks ()) {
4514
- auto *adjointTrampolineBB = adjoint.createBasicBlockBefore (adjointBB);
4515
- pullbackTrampolineBBMap.insert ({{origBB, succBB}, adjointTrampolineBB});
4514
+ auto *pullbackTrampolineBB =
4515
+ pullback.createBasicBlockBefore (pullbackBB);
4516
+ pullbackTrampolineBBMap.insert ({{origBB, succBB},
4517
+ pullbackTrampolineBB});
4516
4518
// Get the enum element type (i.e. the pullback struct type). The enum
4517
4519
// element type may be boxed if the enum is indirect.
4518
4520
auto enumLoweredTy =
@@ -4521,16 +4523,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4521
4523
getPullbackInfo ().lookUpPredecessorEnumElement (origBB, succBB);
4522
4524
auto enumEltType = remapType (
4523
4525
enumLoweredTy.getEnumElementType (enumEltDecl, getModule ()));
4524
- adjointTrampolineBB ->createPhiArgument (enumEltType,
4526
+ pullbackTrampolineBB ->createPhiArgument (enumEltType,
4525
4527
ValueOwnershipKind::Guaranteed);
4526
4528
}
4527
4529
}
4528
4530
4529
- auto *adjointEntry = adjoint .getEntryBlock ();
4530
- // The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
4531
- auto adjParamArgs = adjoint .getArgumentsWithoutIndirectResults ();
4532
- assert (adjParamArgs .size () == 2 );
4533
- seed = adjParamArgs [0 ];
4531
+ auto *pullbackEntry = pullback .getEntryBlock ();
4532
+ // The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
4533
+ auto pbParamArgs = pullback .getArgumentsWithoutIndirectResults ();
4534
+ assert (pbParamArgs .size () == 2 );
4535
+ seed = pbParamArgs [0 ];
4534
4536
4535
4537
// Assign adjoint for original result.
4536
4538
SmallVector<SILValue, 8 > origFormalResults;
@@ -4550,22 +4552,22 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4550
4552
}
4551
4553
}
4552
4554
builder.setInsertionPoint (
4553
- adjointEntry , getNextFunctionLocalAllocationInsertionPoint ());
4555
+ pullbackEntry , getNextFunctionLocalAllocationInsertionPoint ());
4554
4556
if (seed->getType ().isAddress ()) {
4555
4557
// Create a local copy so that it can be written to by later adjoint
4556
4558
// zero'ing logic.
4557
- auto *seedBufCopy = builder.createAllocStack (adjLoc , seed->getType ());
4558
- builder.createCopyAddr (adjLoc , seed, seedBufCopy, IsNotTake,
4559
+ auto *seedBufCopy = builder.createAllocStack (pbLoc , seed->getType ());
4560
+ builder.createCopyAddr (pbLoc , seed, seedBufCopy, IsNotTake,
4559
4561
IsInitialization);
4560
4562
if (seed->getType ().isLoadable (builder.getFunction ()))
4561
- builder.createRetainValueAddr (adjLoc , seedBufCopy,
4563
+ builder.createRetainValueAddr (pbLoc , seedBufCopy,
4562
4564
builder.getDefaultAtomicity ());
4563
4565
ValueWithCleanup seedBufferCopyWithCleanup (
4564
4566
seedBufCopy, makeCleanup (seedBufCopy, emitCleanup));
4565
4567
setAdjointBuffer (origExit, origResult, seedBufferCopyWithCleanup);
4566
4568
functionLocalAllocations.push_back (seedBufferCopyWithCleanup);
4567
4569
} else {
4568
- builder.createRetainValue (adjLoc , seed, builder.getDefaultAtomicity ());
4570
+ builder.createRetainValue (pbLoc , seed, builder.getDefaultAtomicity ());
4569
4571
initializeAdjointValue (origExit, origResult, makeConcreteAdjointValue (
4570
4572
ValueWithCleanup (seed, makeCleanup (seed, emitCleanup))));
4571
4573
}
@@ -4615,7 +4617,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4615
4617
auto *predEnumField =
4616
4618
getPullbackInfo ().lookUpPullbackStructPredecessorField (bb);
4617
4619
auto *predEnumVal =
4618
- builder.createStructExtract (adjLoc , pbStructVal, predEnumField);
4620
+ builder.createStructExtract (pbLoc , pbStructVal, predEnumField);
4619
4621
4620
4622
// Propagate adjoint values from active basic block arguments to
4621
4623
// predecessor terminator operands.
@@ -4666,11 +4668,11 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4666
4668
if (activeValue->getType ().isObject ()) {
4667
4669
auto activeValueAdj = getAdjointValue (bb, activeValue);
4668
4670
auto concreteActiveValueAdj =
4669
- materializeAdjointDirect (activeValueAdj, adjLoc );
4671
+ materializeAdjointDirect (activeValueAdj, pbLoc );
4670
4672
// Emit cleanups for children.
4671
4673
if (auto *cleanup = concreteActiveValueAdj.getCleanup ()) {
4672
4674
cleanup->disable ();
4673
- cleanup->applyRecursively (builder, adjLoc );
4675
+ cleanup->applyRecursively (builder, pbLoc );
4674
4676
}
4675
4677
trampolineArguments.push_back (concreteActiveValueAdj);
4676
4678
// If the adjoint block does not yet have a registered adjoint
@@ -4694,7 +4696,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4694
4696
predAdjBuf.setCleanup (makeCleanupFromChildren (
4695
4697
{adjBuf.getCleanup (), predAdjBuf.getCleanup ()}));
4696
4698
builder.createCopyAddr (
4697
- adjLoc , adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
4699
+ pbLoc , adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
4698
4700
}
4699
4701
}
4700
4702
// Propagate pullback struct argument.
@@ -4705,14 +4707,14 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4705
4707
trampolineArguments.push_back (predPBStructVal);
4706
4708
} else {
4707
4709
auto *projectBox = adjointTrampolineBBBuilder.createProjectBox (
4708
- adjLoc , predPBStructVal, /* index*/ 0 );
4710
+ pbLoc , predPBStructVal, /* index*/ 0 );
4709
4711
auto *loadInst = adjointTrampolineBBBuilder.createLoad (
4710
- adjLoc , projectBox,
4711
- getBufferLOQ (projectBox->getType ().getASTType (), adjoint ));
4712
+ pbLoc , projectBox,
4713
+ getBufferLOQ (projectBox->getType ().getASTType (), pullback ));
4712
4714
trampolineArguments.push_back (loadInst);
4713
4715
}
4714
4716
// Branch from adjoint trampoline block to adjoint block.
4715
- adjointTrampolineBBBuilder.createBranch (adjLoc , adjointBB,
4717
+ adjointTrampolineBBBuilder.createBranch (pbLoc , adjointBB,
4716
4718
trampolineArguments);
4717
4719
}
4718
4720
auto *enumEltDecl =
@@ -4734,15 +4736,15 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4734
4736
SILBasicBlock *adjointSuccBB;
4735
4737
std::tie (enumEltDecl, adjointSuccBB) = adjointSuccessorCases.front ();
4736
4738
auto *predPBStructVal =
4737
- builder.createUncheckedEnumData (adjLoc , predEnumVal, enumEltDecl);
4738
- builder.createBranch (adjLoc , adjointSuccBB, {predPBStructVal});
4739
+ builder.createUncheckedEnumData (pbLoc , predEnumVal, enumEltDecl);
4740
+ builder.createBranch (pbLoc , adjointSuccBB, {predPBStructVal});
4739
4741
}
4740
4742
// - Otherwise, if the original block has multiple predecessors, then the
4741
4743
// adjoint block has multiple successors. Do `switch_enum` to branch on
4742
4744
// the predecessor enum values to adjoint successor blocks.
4743
4745
else {
4744
4746
builder.createSwitchEnum (
4745
- adjLoc , predEnumVal, /* DefaultBB*/ nullptr , adjointSuccessorCases);
4747
+ pbLoc , predEnumVal, /* DefaultBB*/ nullptr , adjointSuccessorCases);
4746
4748
}
4747
4749
}
4748
4750
@@ -4769,14 +4771,14 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4769
4771
auto origParam = origParams[parameterIndex];
4770
4772
if (origParam->getType ().isObject ()) {
4771
4773
auto adjVal = getAdjointValue (origEntry, origParam);
4772
- auto val = materializeAdjointDirect (adjVal, adjLoc );
4774
+ auto val = materializeAdjointDirect (adjVal, pbLoc );
4773
4775
if (auto *cleanup = val.getCleanup ()) {
4774
4776
LLVM_DEBUG (getADDebugStream () << " Disabling cleanup for "
4775
4777
<< val.getValue () << " for return\n " );
4776
4778
cleanup->disable ();
4777
4779
LLVM_DEBUG (getADDebugStream () << " Applying "
4778
4780
<< cleanup->getNumChildren () << " child cleanups\n " );
4779
- cleanup->applyRecursively (builder, adjLoc );
4781
+ cleanup->applyRecursively (builder, pbLoc );
4780
4782
}
4781
4783
retElts.push_back (val);
4782
4784
} else {
@@ -4795,12 +4797,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4795
4797
4796
4798
// Disable cleanup for original indirect parameter adjoint buffers.
4797
4799
// Copy them to adjoint indirect results.
4798
- assert (indParamAdjoints.size () == adjoint .getIndirectResults ().size () &&
4800
+ assert (indParamAdjoints.size () == pullback .getIndirectResults ().size () &&
4799
4801
" Indirect parameter adjoint count mismatch" );
4800
- for (auto pair : zip (indParamAdjoints, adjoint .getIndirectResults ())) {
4802
+ for (auto pair : zip (indParamAdjoints, pullback .getIndirectResults ())) {
4801
4803
auto &source = std::get<0 >(pair);
4802
4804
auto &dest = std::get<1 >(pair);
4803
- builder.createCopyAddr (adjLoc , source, dest, IsTake, IsInitialization);
4805
+ builder.createCopyAddr (pbLoc , source, dest, IsTake, IsInitialization);
4804
4806
if (auto *cleanup = source.getCleanup ())
4805
4807
cleanup->disable ();
4806
4808
}
@@ -4812,13 +4814,13 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4812
4814
// Buffers should not be allocated needlessly.
4813
4815
assert (!alloc.getValue ()->use_empty ());
4814
4816
if (auto *cleanup = alloc.getCleanup ())
4815
- cleanup->applyRecursively (builder, adjLoc );
4816
- builder.createDeallocStack (adjLoc , alloc);
4817
+ cleanup->applyRecursively (builder, pbLoc );
4818
+ builder.createDeallocStack (pbLoc , alloc);
4817
4819
}
4818
- builder.createReturn (adjLoc , joinElements (retElts, builder, adjLoc ));
4820
+ builder.createReturn (pbLoc , joinElements (retElts, builder, pbLoc ));
4819
4821
4820
- LLVM_DEBUG (getADDebugStream () << " Generated adjoint for "
4821
- << original.getName () << " :\n " << adjoint );
4822
+ LLVM_DEBUG (getADDebugStream () << " Generated pullback for "
4823
+ << original.getName () << " :\n " << pullback );
4822
4824
return errorOccurred;
4823
4825
}
4824
4826
0 commit comments