Skip to content

Commit a040d57

Browse files
bartchr808rxwei
authored andcommitted
---
yaml --- r: 340990 b: refs/heads/rxwei-patch-1 c: 58ccca2 h: refs/heads/master
1 parent 32c7b4a commit a040d57

File tree

2 files changed

+51
-49
lines changed

2 files changed

+51
-49
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: c1dfcde1a4ddc3a6aad97571aeb38fab5f9ce37d
1018+
refs/heads/rxwei-patch-1: 58ccca2f6a1c5f75d186cbbc5466596a6bfc8f71
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4378,8 +4378,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
43784378
/// any error occurs.
43794379
bool run() {
43804380
auto &original = getOriginal();
4381-
auto &adjoint = getPullback();
4382-
auto adjLoc = getPullback().getLocation();
4381+
auto &pullback = getPullback();
4382+
auto pbLoc = getPullback().getLocation();
43834383
LLVM_DEBUG(getADDebugStream() << "Running PullbackEmitter on\n" << original);
43844384

43854385
auto *adjGenEnv = getPullback().getGenericEnvironment();
@@ -4446,7 +4446,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
44464446
if (errorOccurred)
44474447
return true;
44484448

4449-
// Create adjoint blocks and arguments, visiting original blocks in
4449+
// Create pullback blocks and arguments, visiting original blocks in
44504450
// post-order post-dominance order.
44514451
SmallVector<SILBasicBlock *, 8> postOrderPostDomOrder;
44524452
// Start from the root node, which may have a marker `nullptr` block if
@@ -4462,17 +4462,17 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
44624462
postOrderPostDomOrder.push_back(origBB);
44634463
}
44644464
for (auto *origBB : postOrderPostDomOrder) {
4465-
auto *adjointBB = adjoint.createBasicBlock();
4466-
adjointBBMap.insert({origBB, adjointBB});
4465+
auto *pullbackBB = pullback.createBasicBlock();
4466+
adjointBBMap.insert({origBB, pullbackBB});
44674467
auto pbStructLoweredType =
44684468
remapType(getPullbackInfo().getPullbackStructLoweredType(origBB));
4469-
// If the BB is the original exit, then the adjoint block that we just
4470-
// created must be the adjoint function's entry. For the adjoint entry,
4469+
// If the BB is the original exit, then the pullback block that we just
4470+
// created must be the pullback function's entry. For the pullback entry,
44714471
// create entry arguments and continue to the next block.
44724472
if (origBB == origExit) {
4473-
assert(adjointBB->isEntry());
4473+
assert(pullbackBB->isEntry());
44744474
createEntryArguments(&getPullback());
4475-
auto *lastArg = adjointBB->getArguments().back();
4475+
auto *lastArg = pullbackBB->getArguments().back();
44764476
assert(lastArg->getType() == pbStructLoweredType);
44774477
pullbackStructArguments[origBB] = lastArg;
44784478
continue;
@@ -4492,27 +4492,29 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
44924492
if (activeValue->getType().isAddress()) {
44934493
// Allocate and zero initialize a new local buffer using
44944494
// `getAdjointBuffer`.
4495-
builder.setInsertionPoint(adjoint.getEntryBlock());
4495+
builder.setInsertionPoint(pullback.getEntryBlock());
44964496
getAdjointBuffer(origBB, activeValue);
44974497
} else {
4498-
// Create and register adjoint block argument for the active value.
4499-
auto *adjointArg = adjointBB->createPhiArgument(
4498+
// Create and register pullback block argument for the active value.
4499+
auto *adjointArg = pullbackBB->createPhiArgument(
45004500
getRemappedTangentType(activeValue->getType()),
45014501
ValueOwnershipKind::Guaranteed);
45024502
activeValueAdjointBBArgumentMap[{origBB, activeValue}] = adjointArg;
45034503
}
45044504
}
45054505
// Add a pullback struct argument.
4506-
auto *pbStructArg = adjointBB->createPhiArgument(
4506+
auto *pbStructArg = pullbackBB->createPhiArgument(
45074507
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
45084508
pullbackStructArguments[origBB] = pbStructArg;
4509-
// - Create adjoint trampoline blocks for each successor block of the
4509+
// - Create pullback trampoline blocks for each successor block of the
45104510
// original block. Adjoint trampoline blocks only have a pullback
45114511
// struct argument, and branch from the adjoint successor block to the
45124512
// adjoint original block, trampoline adjoint values of active values.
45134513
for (auto *succBB : origBB->getSuccessorBlocks()) {
4514-
auto *adjointTrampolineBB = adjoint.createBasicBlockBefore(adjointBB);
4515-
pullbackTrampolineBBMap.insert({{origBB, succBB}, adjointTrampolineBB});
4514+
auto *pullbackTrampolineBB =
4515+
pullback.createBasicBlockBefore(pullbackBB);
4516+
pullbackTrampolineBBMap.insert({{origBB, succBB},
4517+
pullbackTrampolineBB});
45164518
// Get the enum element type (i.e. the pullback struct type). The enum
45174519
// element type may be boxed if the enum is indirect.
45184520
auto enumLoweredTy =
@@ -4521,16 +4523,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
45214523
getPullbackInfo().lookUpPredecessorEnumElement(origBB, succBB);
45224524
auto enumEltType = remapType(
45234525
enumLoweredTy.getEnumElementType(enumEltDecl, getModule()));
4524-
adjointTrampolineBB->createPhiArgument(enumEltType,
4526+
pullbackTrampolineBB->createPhiArgument(enumEltType,
45254527
ValueOwnershipKind::Guaranteed);
45264528
}
45274529
}
45284530

4529-
auto *adjointEntry = adjoint.getEntryBlock();
4530-
// The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
4531-
auto adjParamArgs = adjoint.getArgumentsWithoutIndirectResults();
4532-
assert(adjParamArgs.size() == 2);
4533-
seed = adjParamArgs[0];
4531+
auto *pullbackEntry = pullback.getEntryBlock();
4532+
// The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
4533+
auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
4534+
assert(pbParamArgs.size() == 2);
4535+
seed = pbParamArgs[0];
45344536

45354537
// Assign adjoint for original result.
45364538
SmallVector<SILValue, 8> origFormalResults;
@@ -4550,22 +4552,22 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
45504552
}
45514553
}
45524554
builder.setInsertionPoint(
4553-
adjointEntry, getNextFunctionLocalAllocationInsertionPoint());
4555+
pullbackEntry, getNextFunctionLocalAllocationInsertionPoint());
45544556
if (seed->getType().isAddress()) {
45554557
// Create a local copy so that it can be written to by later adjoint
45564558
// zero'ing logic.
4557-
auto *seedBufCopy = builder.createAllocStack(adjLoc, seed->getType());
4558-
builder.createCopyAddr(adjLoc, seed, seedBufCopy, IsNotTake,
4559+
auto *seedBufCopy = builder.createAllocStack(pbLoc, seed->getType());
4560+
builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
45594561
IsInitialization);
45604562
if (seed->getType().isLoadable(builder.getFunction()))
4561-
builder.createRetainValueAddr(adjLoc, seedBufCopy,
4563+
builder.createRetainValueAddr(pbLoc, seedBufCopy,
45624564
builder.getDefaultAtomicity());
45634565
ValueWithCleanup seedBufferCopyWithCleanup(
45644566
seedBufCopy, makeCleanup(seedBufCopy, emitCleanup));
45654567
setAdjointBuffer(origExit, origResult, seedBufferCopyWithCleanup);
45664568
functionLocalAllocations.push_back(seedBufferCopyWithCleanup);
45674569
} else {
4568-
builder.createRetainValue(adjLoc, seed, builder.getDefaultAtomicity());
4570+
builder.createRetainValue(pbLoc, seed, builder.getDefaultAtomicity());
45694571
initializeAdjointValue(origExit, origResult, makeConcreteAdjointValue(
45704572
ValueWithCleanup(seed, makeCleanup(seed, emitCleanup))));
45714573
}
@@ -4615,7 +4617,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
46154617
auto *predEnumField =
46164618
getPullbackInfo().lookUpPullbackStructPredecessorField(bb);
46174619
auto *predEnumVal =
4618-
builder.createStructExtract(adjLoc, pbStructVal, predEnumField);
4620+
builder.createStructExtract(pbLoc, pbStructVal, predEnumField);
46194621

46204622
// Propagate adjoint values from active basic block arguments to
46214623
// predecessor terminator operands.
@@ -4666,11 +4668,11 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
46664668
if (activeValue->getType().isObject()) {
46674669
auto activeValueAdj = getAdjointValue(bb, activeValue);
46684670
auto concreteActiveValueAdj =
4669-
materializeAdjointDirect(activeValueAdj, adjLoc);
4671+
materializeAdjointDirect(activeValueAdj, pbLoc);
46704672
// Emit cleanups for children.
46714673
if (auto *cleanup = concreteActiveValueAdj.getCleanup()) {
46724674
cleanup->disable();
4673-
cleanup->applyRecursively(builder, adjLoc);
4675+
cleanup->applyRecursively(builder, pbLoc);
46744676
}
46754677
trampolineArguments.push_back(concreteActiveValueAdj);
46764678
// If the adjoint block does not yet have a registered adjoint
@@ -4694,7 +4696,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
46944696
predAdjBuf.setCleanup(makeCleanupFromChildren(
46954697
{adjBuf.getCleanup(), predAdjBuf.getCleanup()}));
46964698
builder.createCopyAddr(
4697-
adjLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
4699+
pbLoc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
46984700
}
46994701
}
47004702
// Propagate pullback struct argument.
@@ -4705,14 +4707,14 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
47054707
trampolineArguments.push_back(predPBStructVal);
47064708
} else {
47074709
auto *projectBox = adjointTrampolineBBBuilder.createProjectBox(
4708-
adjLoc, predPBStructVal, /*index*/ 0);
4710+
pbLoc, predPBStructVal, /*index*/ 0);
47094711
auto *loadInst = adjointTrampolineBBBuilder.createLoad(
4710-
adjLoc, projectBox,
4711-
getBufferLOQ(projectBox->getType().getASTType(), adjoint));
4712+
pbLoc, projectBox,
4713+
getBufferLOQ(projectBox->getType().getASTType(), pullback));
47124714
trampolineArguments.push_back(loadInst);
47134715
}
47144716
// Branch from adjoint trampoline block to adjoint block.
4715-
adjointTrampolineBBBuilder.createBranch(adjLoc, adjointBB,
4717+
adjointTrampolineBBBuilder.createBranch(pbLoc, adjointBB,
47164718
trampolineArguments);
47174719
}
47184720
auto *enumEltDecl =
@@ -4734,15 +4736,15 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
47344736
SILBasicBlock *adjointSuccBB;
47354737
std::tie(enumEltDecl, adjointSuccBB) = adjointSuccessorCases.front();
47364738
auto *predPBStructVal =
4737-
builder.createUncheckedEnumData(adjLoc, predEnumVal, enumEltDecl);
4738-
builder.createBranch(adjLoc, adjointSuccBB, {predPBStructVal});
4739+
builder.createUncheckedEnumData(pbLoc, predEnumVal, enumEltDecl);
4740+
builder.createBranch(pbLoc, adjointSuccBB, {predPBStructVal});
47394741
}
47404742
// - Otherwise, if the original block has multiple predecessors, then the
47414743
// adjoint block has multiple successors. Do `switch_enum` to branch on
47424744
// the predecessor enum values to adjoint successor blocks.
47434745
else {
47444746
builder.createSwitchEnum(
4745-
adjLoc, predEnumVal, /*DefaultBB*/ nullptr, adjointSuccessorCases);
4747+
pbLoc, predEnumVal, /*DefaultBB*/ nullptr, adjointSuccessorCases);
47464748
}
47474749
}
47484750

@@ -4769,14 +4771,14 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
47694771
auto origParam = origParams[parameterIndex];
47704772
if (origParam->getType().isObject()) {
47714773
auto adjVal = getAdjointValue(origEntry, origParam);
4772-
auto val = materializeAdjointDirect(adjVal, adjLoc);
4774+
auto val = materializeAdjointDirect(adjVal, pbLoc);
47734775
if (auto *cleanup = val.getCleanup()) {
47744776
LLVM_DEBUG(getADDebugStream() << "Disabling cleanup for "
47754777
<< val.getValue() << "for return\n");
47764778
cleanup->disable();
47774779
LLVM_DEBUG(getADDebugStream() << "Applying "
47784780
<< cleanup->getNumChildren() << " child cleanups\n");
4779-
cleanup->applyRecursively(builder, adjLoc);
4781+
cleanup->applyRecursively(builder, pbLoc);
47804782
}
47814783
retElts.push_back(val);
47824784
} else {
@@ -4795,12 +4797,12 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
47954797

47964798
// Disable cleanup for original indirect parameter adjoint buffers.
47974799
// Copy them to adjoint indirect results.
4798-
assert(indParamAdjoints.size() == adjoint.getIndirectResults().size() &&
4800+
assert(indParamAdjoints.size() == pullback.getIndirectResults().size() &&
47994801
"Indirect parameter adjoint count mismatch");
4800-
for (auto pair : zip(indParamAdjoints, adjoint.getIndirectResults())) {
4802+
for (auto pair : zip(indParamAdjoints, pullback.getIndirectResults())) {
48014803
auto &source = std::get<0>(pair);
48024804
auto &dest = std::get<1>(pair);
4803-
builder.createCopyAddr(adjLoc, source, dest, IsTake, IsInitialization);
4805+
builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
48044806
if (auto *cleanup = source.getCleanup())
48054807
cleanup->disable();
48064808
}
@@ -4812,13 +4814,13 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
48124814
// Buffers should not be allocated needlessly.
48134815
assert(!alloc.getValue()->use_empty());
48144816
if (auto *cleanup = alloc.getCleanup())
4815-
cleanup->applyRecursively(builder, adjLoc);
4816-
builder.createDeallocStack(adjLoc, alloc);
4817+
cleanup->applyRecursively(builder, pbLoc);
4818+
builder.createDeallocStack(pbLoc, alloc);
48174819
}
4818-
builder.createReturn(adjLoc, joinElements(retElts, builder, adjLoc));
4820+
builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc));
48194821

4820-
LLVM_DEBUG(getADDebugStream() << "Generated adjoint for "
4821-
<< original.getName() << ":\n" << adjoint);
4822+
LLVM_DEBUG(getADDebugStream() << "Generated pullback for "
4823+
<< original.getName() << ":\n" << pullback);
48224824
return errorOccurred;
48234825
}
48244826

0 commit comments

Comments
 (0)