Skip to content

Commit 3b27ffb

Browse files
authored
---
yaml --- r: 294867 b: refs/heads/tensorflow c: 5ac83e2 h: refs/heads/master i: 294865: a99532e 294863: 4d69ce8
1 parent cb9cc02 commit 3b27ffb

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
816816
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
817817
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
818818
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
819-
refs/heads/tensorflow: 1c1305ca1059a45e4f5a90bedbdb008ea0a10ab2
819+
refs/heads/tensorflow: 5ac83e25b48a79ed6a093a7e32a98a52fc138836
820820
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
821821
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4131,6 +4131,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41314131
: nullptr;
41324132
Lowering::GenericContextScope genericContextScope(
41334133
getContext().getTypeConverter(), adjGenSig);
4134+
auto origExitIt = original.findReturnBB();
4135+
assert(origExitIt != original.end() &&
4136+
"Functions without returns must have been diagnosed");
4137+
auto *origExit = &*origExitIt;
41344138

41354139
// Get dominated active values in original blocks.
41364140
// Adjoint values of dominated active values are passed as adjoint block
@@ -4175,12 +4179,22 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
41754179
for (auto *origBB : postOrderInfo->getPostOrder()) {
41764180
auto *adjointBB = adjoint.createBasicBlock();
41774181
adjointBBMap.insert({origBB, adjointBB});
4178-
// If adjoint block is the adjoint entry, continue.
4179-
if (adjointBB->isEntry())
4180-
continue;
4181-
// Otherwise, add a pullback struct argument to the adjoint block.
41824182
auto pbStructLoweredType =
41834183
remapType(getPullbackInfo().getPullbackStructLoweredType(origBB));
4184+
// If the BB is the original exit, then the adjoint block that we just
4185+
// created must be the adjoint function's entry. For the adjoint entry,
4186+
// create entry arguments and continue to the next block.
4187+
if (origBB == origExit) {
4188+
assert(adjointBB->isEntry());
4189+
createEntryArguments(&getAdjoint());
4190+
auto *lastArg = adjointBB->getArguments().back();
4191+
assert(lastArg->getType() == pbStructLoweredType);
4192+
adjointPullbackStructArguments[origBB] = lastArg;
4193+
continue;
4194+
}
4195+
4196+
// Otherwise, we create a phi argument for the corresponding pullback
4197+
// struct, and handle dominated active values/buffers.
41844198
auto *pbStructArg = adjointBB->createPhiArgument(
41854199
pbStructLoweredType, ValueOwnershipKind::Guaranteed);
41864200
adjointPullbackStructArguments[origBB] = pbStructArg;
@@ -4222,15 +4236,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
42224236
}
42234237
}
42244238

4225-
auto *origEntry = original.getEntryBlock();
4226-
auto *origExit = &*original.findReturnBB();
42274239
auto *adjointEntry = adjoint.getEntryBlock();
4228-
createEntryArguments(&adjoint);
42294240
// The adjoint function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
42304241
auto adjParamArgs = adjoint.getArgumentsWithoutIndirectResults();
42314242
assert(adjParamArgs.size() == 2);
42324243
seed = adjParamArgs[0];
4233-
adjointPullbackStructArguments[origExit] = adjParamArgs[1];
42344244

42354245
// Assign adjoint for original result.
42364246
SmallVector<SILValue, 8> origFormalResults;
@@ -4434,6 +4444,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
44344444
// Place the builder at the adjoint exit, i.e. the adjoint block
44354445
// corresponding to the original entry. Return the adjoints wrt parameters
44364446
// in the adjoint exit.
4447+
auto *origEntry = getOriginal().getEntryBlock();
44374448
builder.setInsertionPoint(getAdjointBlock(origEntry));
44384449

44394450
// This vector will contain all the materialized return elements.

0 commit comments

Comments
 (0)