Skip to content

Commit 6364099

Browse files
authored
---
yaml --- r: 340982 b: refs/heads/rxwei-patch-1 c: 6b3d874 h: refs/heads/master
1 parent 22a6926 commit 6364099

File tree

2 files changed

+42
-50
lines changed

2 files changed

+42
-50
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: 73984ad36b531fdecebe9d51644e6cdb34afc82c
1018+
refs/heads/rxwei-patch-1: 6b3d8742e82ca0b76100ad6a802ffcb8e2bf11ee
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4004,7 +4004,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
40044004
assert(tanFieldLookup.size() == 1);
40054005
auto *tanField = cast<VarDecl>(tanFieldLookup.front());
40064006
return builder.createStructElementAddr(
4007-
seai->getLoc(), adjSource.getValue(), tanField);
4007+
seai->getLoc(), adjSource.getValue(), tanField);
40084008
}
40094009
// Handle `tuple_element_addr`.
40104010
if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
@@ -4649,50 +4649,45 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
46494649
errorOccurred = true;
46504650
}
46514651

4652-
AllocStackInst *emitDifferentiableViewSubscript(
4653-
ApplyInst *ai, SILType elType, SILValue adjointArray, SILValue fnRef,
4654-
CanGenericSignature genericSig, int index) {
4652+
AllocStackInst *
4653+
emitDifferentiableViewSubscript(ApplyInst *ai, SILType elType,
4654+
SILValue adjointArray, SILValue fnRef,
4655+
CanGenericSignature genericSig, int index) {
46554656
auto &ctx = builder.getASTContext();
46564657
auto astType = elType.getASTType();
46574658
auto literal = builder.createIntegerLiteral(
46584659
ai->getLoc(), SILType::getBuiltinIntegerType(64, ctx), index);
46594660
auto intType = SILType::getPrimitiveObjectType(
46604661
ctx.getIntDecl()->getDeclaredType()->getCanonicalType());
4661-
auto intStruct = builder.createStruct(
4662-
ai->getLoc(), intType, {literal});
4663-
AllocStackInst *subscriptBuffer = builder.createAllocStack(
4664-
ai->getLoc(), elType);
4662+
auto intStruct = builder.createStruct(ai->getLoc(), intType, {literal});
4663+
AllocStackInst *subscriptBuffer =
4664+
builder.createAllocStack(ai->getLoc(), elType);
46654665
auto swiftModule = getModule().getSwiftModule();
4666-
auto diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
4667-
auto diffConf = swiftModule->lookupConformance(
4668-
astType, diffProto);
4666+
auto diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
4667+
auto diffConf = swiftModule->lookupConformance(astType, diffProto);
46694668
assert(diffConf.hasValue() && "Missing conformance to `Differentiable`");
46704669
auto addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
4671-
auto addArithConf = swiftModule->lookupConformance(
4672-
astType, addArithProto);
4670+
auto addArithConf = swiftModule->lookupConformance(astType, addArithProto);
46734671
assert(addArithConf.hasValue() &&
46744672
"Missing conformance to `AdditiveArithmetic`");
4675-
auto subMap = SubstitutionMap::get(
4676-
genericSig, {astType},
4677-
{*addArithConf, *diffConf});
4678-
auto subscriptApply = builder.createApply(
4679-
ai->getLoc(), fnRef, subMap,
4680-
{subscriptBuffer, intStruct, adjointArray});
4673+
auto subMap =
4674+
SubstitutionMap::get(genericSig, {astType}, {*addArithConf, *diffConf});
4675+
builder.createApply(ai->getLoc(), fnRef, subMap,
4676+
{subscriptBuffer, intStruct, adjointArray});
46814677
return subscriptBuffer;
46824678
}
46834679

4684-
void accumulateDifferentiableViewSubscriptDirect(
4685-
ApplyInst *ai, SILType elType, StoreInst *si,
4686-
AllocStackInst *subscriptBuffer) {
4680+
void
4681+
accumulateDifferentiableViewSubscriptDirect(ApplyInst *ai, SILType elType,
4682+
StoreInst *si,
4683+
AllocStackInst *subscriptBuffer) {
46874684
auto astType = elType.getASTType();
4688-
auto newAdjValue = builder.createLoad(
4689-
ai->getLoc(), subscriptBuffer, getBufferLOQ(astType, getPullback()));
4690-
addAdjointValue(
4691-
si->getParent(), si->getSrc(),
4692-
makeConcreteAdjointValue(ValueWithCleanup(
4693-
newAdjValue, makeCleanup(newAdjValue, emitCleanup))));
4694-
builder.createDeallocStack(
4695-
ai->getLoc(), subscriptBuffer);
4685+
auto newAdjValue = builder.createLoad(ai->getLoc(), subscriptBuffer,
4686+
getBufferLOQ(astType, getPullback()));
4687+
addAdjointValue(si->getParent(), si->getSrc(),
4688+
makeConcreteAdjointValue(ValueWithCleanup(
4689+
newAdjValue, makeCleanup(newAdjValue, emitCleanup))));
4690+
builder.createDeallocStack(ai->getLoc(), subscriptBuffer);
46964691
}
46974692

46984693
void accumulateDifferentiableViewSubscriptIndirect(
@@ -4701,19 +4696,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
47014696
ai->getLoc(), subscriptBuffer, SILAccessKind::Read,
47024697
SILAccessEnforcement::Static, /*noNestedConflict*/ true,
47034698
/*fromBuiltin*/ false);
4704-
addToAdjointBuffer(
4705-
cai->getParent(), cai->getSrc(), subscriptBufferAccess);
4706-
builder.createEndAccess(
4707-
ai->getLoc(), subscriptBufferAccess, /*aborted*/ false);
4699+
addToAdjointBuffer(cai->getParent(), cai->getSrc(), subscriptBufferAccess);
4700+
builder.createEndAccess(ai->getLoc(), subscriptBufferAccess,
4701+
/*aborted*/ false);
47084702
builder.createDeallocStack(ai->getLoc(), subscriptBuffer);
47094703
}
47104704

47114705
void visitArrayInitialization(ApplyInst *ai) {
47124706
SILValue adjointArray;
47134707
SILValue fnRef;
47144708
CanGenericSignature genericSig;
4715-
auto lookupConformance = LookUpConformanceInModule(
4716-
getModule().getSwiftModule());
47174709
for (auto use : ai->getUses()) {
47184710
auto tei = dyn_cast<TupleExtractInst>(use->getUser()->getResult(0));
47194711
if (!tei || tei->getFieldNo() != 0) continue;
@@ -5165,9 +5157,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
51655157
}
51665158
}
51675159

5168-
// Handle `load` instruction.
5169-
// Original: y = load x
5170-
// Adjoint: adj[x] += adj[y]
5160+
/// Handle `load` instruction.
5161+
/// Original: y = load x
5162+
/// Adjoint: adj[x] += adj[y]
51715163
void visitLoadInst(LoadInst *li) {
51725164
auto *bb = li->getParent();
51735165
auto adjVal = materializeAdjointDirect(getAdjointValue(bb, li), li->getLoc());
@@ -5199,9 +5191,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
51995191
builder.createDeallocStack(li->getLoc(), localBuf);
52005192
}
52015193

5202-
// Handle `store` instruction.
5203-
// Original: store x to y
5204-
// Adjoint: adj[x] += load adj[y]; adj[y] = 0
5194+
/// Handle `store` instruction.
5195+
/// Original: store x to y
5196+
/// Adjoint: adj[x] += load adj[y]; adj[y] = 0
52055197
void visitStoreInst(StoreInst *si) {
52065198
auto *bb = si->getParent();
52075199
auto &adjBuf = getAdjointBuffer(bb, si->getDest());
@@ -5222,9 +5214,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
52225214
emitZeroIndirect(bufType.getASTType(), adjBuf, si->getLoc());
52235215
}
52245216

5225-
// Handle `copy_addr` instruction.
5226-
// Original: copy_addr x to y
5227-
// Adjoint: adj[x] += adj[y]; adj[y] = 0
5217+
/// Handle `copy_addr` instruction.
5218+
/// Original: copy_addr x to y
5219+
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
52285220
void visitCopyAddrInst(CopyAddrInst *cai) {
52295221
auto *bb = cai->getParent();
52305222
auto &adjDest = getAdjointBuffer(bb, cai->getDest());
@@ -5250,9 +5242,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
52505242
adjDest.setCleanup(cleanup);
52515243
}
52525244

5253-
// Handle `begin_access` instruction.
5254-
// Original: y = begin_access x
5255-
// Adjoint: nothing
5245+
/// Handle `begin_access` instruction.
5246+
/// Original: y = begin_access x
5247+
/// Adjoint: nothing (differentiability checks, cleanup propagation)
52565248
void visitBeginAccessInst(BeginAccessInst *bai) {
52575249
// Check for non-differentiable writes.
52585250
if (bai->getAccessKind() == SILAccessKind::Modify) {
@@ -5291,7 +5283,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
52915283
#define NOT_DIFFERENTIABLE(INST, DIAG) \
52925284
void visit##INST##Inst(INST##Inst *inst) { \
52935285
getContext().emitNondifferentiabilityError( \
5294-
inst, getDifferentiationTask(), DIAG); \
5286+
inst, getInvoker(), DIAG); \
52955287
errorOccurred = true; \
52965288
return; \
52975289
}

0 commit comments

Comments
 (0)