@@ -4004,7 +4004,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4004
4004
assert (tanFieldLookup.size () == 1 );
4005
4005
auto *tanField = cast<VarDecl>(tanFieldLookup.front ());
4006
4006
return builder.createStructElementAddr (
4007
- seai->getLoc (), adjSource.getValue (), tanField);
4007
+ seai->getLoc (), adjSource.getValue (), tanField);
4008
4008
}
4009
4009
// Handle `tuple_element_addr`.
4010
4010
if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
@@ -4649,50 +4649,45 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4649
4649
errorOccurred = true ;
4650
4650
}
4651
4651
4652
- AllocStackInst *emitDifferentiableViewSubscript (
4653
- ApplyInst *ai, SILType elType, SILValue adjointArray, SILValue fnRef,
4654
- CanGenericSignature genericSig, int index) {
4652
+ AllocStackInst *
4653
+ emitDifferentiableViewSubscript (ApplyInst *ai, SILType elType,
4654
+ SILValue adjointArray, SILValue fnRef,
4655
+ CanGenericSignature genericSig, int index) {
4655
4656
auto &ctx = builder.getASTContext ();
4656
4657
auto astType = elType.getASTType ();
4657
4658
auto literal = builder.createIntegerLiteral (
4658
4659
ai->getLoc (), SILType::getBuiltinIntegerType (64 , ctx), index);
4659
4660
auto intType = SILType::getPrimitiveObjectType (
4660
4661
ctx.getIntDecl ()->getDeclaredType ()->getCanonicalType ());
4661
- auto intStruct = builder.createStruct (
4662
- ai->getLoc (), intType, {literal});
4663
- AllocStackInst *subscriptBuffer = builder.createAllocStack (
4664
- ai->getLoc (), elType);
4662
+ auto intStruct = builder.createStruct (ai->getLoc (), intType, {literal});
4663
+ AllocStackInst *subscriptBuffer =
4664
+ builder.createAllocStack (ai->getLoc (), elType);
4665
4665
auto swiftModule = getModule ().getSwiftModule ();
4666
- auto diffProto = ctx.getProtocol (KnownProtocolKind::Differentiable);
4667
- auto diffConf = swiftModule->lookupConformance (
4668
- astType, diffProto);
4666
+ auto diffProto = ctx.getProtocol (KnownProtocolKind::Differentiable);
4667
+ auto diffConf = swiftModule->lookupConformance (astType, diffProto);
4669
4668
assert (diffConf.hasValue () && " Missing conformance to `Differentiable`" );
4670
4669
auto addArithProto = ctx.getProtocol (KnownProtocolKind::AdditiveArithmetic);
4671
- auto addArithConf = swiftModule->lookupConformance (
4672
- astType, addArithProto);
4670
+ auto addArithConf = swiftModule->lookupConformance (astType, addArithProto);
4673
4671
assert (addArithConf.hasValue () &&
4674
4672
" Missing conformance to `AdditiveArithmetic`" );
4675
- auto subMap = SubstitutionMap::get (
4676
- genericSig, {astType},
4677
- {*addArithConf, *diffConf});
4678
- auto subscriptApply = builder.createApply (
4679
- ai->getLoc (), fnRef, subMap,
4680
- {subscriptBuffer, intStruct, adjointArray});
4673
+ auto subMap =
4674
+ SubstitutionMap::get (genericSig, {astType}, {*addArithConf, *diffConf});
4675
+ builder.createApply (ai->getLoc (), fnRef, subMap,
4676
+ {subscriptBuffer, intStruct, adjointArray});
4681
4677
return subscriptBuffer;
4682
4678
}
4683
4679
4684
- void accumulateDifferentiableViewSubscriptDirect (
4685
- ApplyInst *ai, SILType elType, StoreInst *si,
4686
- AllocStackInst *subscriptBuffer) {
4680
+ void
4681
+ accumulateDifferentiableViewSubscriptDirect (ApplyInst *ai, SILType elType,
4682
+ StoreInst *si,
4683
+ AllocStackInst *subscriptBuffer) {
4687
4684
auto astType = elType.getASTType ();
4688
- auto newAdjValue = builder.createLoad (
4689
- ai->getLoc (), subscriptBuffer, getBufferLOQ (astType, getPullback ()));
4690
- addAdjointValue (
4691
- si->getParent (), si->getSrc (),
4692
- makeConcreteAdjointValue (ValueWithCleanup (
4693
- newAdjValue, makeCleanup (newAdjValue, emitCleanup))));
4694
- builder.createDeallocStack (
4695
- ai->getLoc (), subscriptBuffer);
4685
+ auto newAdjValue = builder.createLoad (ai->getLoc (), subscriptBuffer,
4686
+ getBufferLOQ (astType, getPullback ()));
4687
+ addAdjointValue (si->getParent (), si->getSrc (),
4688
+ makeConcreteAdjointValue (ValueWithCleanup (
4689
+ newAdjValue, makeCleanup (newAdjValue, emitCleanup))));
4690
+ builder.createDeallocStack (ai->getLoc (), subscriptBuffer);
4696
4691
}
4697
4692
4698
4693
void accumulateDifferentiableViewSubscriptIndirect (
@@ -4701,19 +4696,16 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
4701
4696
ai->getLoc (), subscriptBuffer, SILAccessKind::Read,
4702
4697
SILAccessEnforcement::Static, /* noNestedConflict*/ true ,
4703
4698
/* fromBuiltin*/ false );
4704
- addToAdjointBuffer (
4705
- cai->getParent (), cai->getSrc (), subscriptBufferAccess);
4706
- builder.createEndAccess (
4707
- ai->getLoc (), subscriptBufferAccess, /* aborted*/ false );
4699
+ addToAdjointBuffer (cai->getParent (), cai->getSrc (), subscriptBufferAccess);
4700
+ builder.createEndAccess (ai->getLoc (), subscriptBufferAccess,
4701
+ /* aborted*/ false );
4708
4702
builder.createDeallocStack (ai->getLoc (), subscriptBuffer);
4709
4703
}
4710
4704
4711
4705
void visitArrayInitialization (ApplyInst *ai) {
4712
4706
SILValue adjointArray;
4713
4707
SILValue fnRef;
4714
4708
CanGenericSignature genericSig;
4715
- auto lookupConformance = LookUpConformanceInModule (
4716
- getModule ().getSwiftModule ());
4717
4709
for (auto use : ai->getUses ()) {
4718
4710
auto tei = dyn_cast<TupleExtractInst>(use->getUser ()->getResult (0 ));
4719
4711
if (!tei || tei->getFieldNo () != 0 ) continue ;
@@ -5165,9 +5157,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5165
5157
}
5166
5158
}
5167
5159
5168
- // Handle `load` instruction.
5169
- // Original: y = load x
5170
- // Adjoint: adj[x] += adj[y]
5160
+ // / Handle `load` instruction.
5161
+ // / Original: y = load x
5162
+ // / Adjoint: adj[x] += adj[y]
5171
5163
void visitLoadInst (LoadInst *li) {
5172
5164
auto *bb = li->getParent ();
5173
5165
auto adjVal = materializeAdjointDirect (getAdjointValue (bb, li), li->getLoc ());
@@ -5199,9 +5191,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5199
5191
builder.createDeallocStack (li->getLoc (), localBuf);
5200
5192
}
5201
5193
5202
- // Handle `store` instruction.
5203
- // Original: store x to y
5204
- // Adjoint: adj[x] += load adj[y]; adj[y] = 0
5194
+ // / Handle `store` instruction.
5195
+ // / Original: store x to y
5196
+ // / Adjoint: adj[x] += load adj[y]; adj[y] = 0
5205
5197
void visitStoreInst (StoreInst *si) {
5206
5198
auto *bb = si->getParent ();
5207
5199
auto &adjBuf = getAdjointBuffer (bb, si->getDest ());
@@ -5222,9 +5214,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5222
5214
emitZeroIndirect (bufType.getASTType (), adjBuf, si->getLoc ());
5223
5215
}
5224
5216
5225
- // Handle `copy_addr` instruction.
5226
- // Original: copy_addr x to y
5227
- // Adjoint: adj[x] += adj[y]; adj[y] = 0
5217
+ // / Handle `copy_addr` instruction.
5218
+ // / Original: copy_addr x to y
5219
+ // / Adjoint: adj[x] += adj[y]; adj[y] = 0
5228
5220
void visitCopyAddrInst (CopyAddrInst *cai) {
5229
5221
auto *bb = cai->getParent ();
5230
5222
auto &adjDest = getAdjointBuffer (bb, cai->getDest ());
@@ -5250,9 +5242,9 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5250
5242
adjDest.setCleanup (cleanup);
5251
5243
}
5252
5244
5253
- // Handle `begin_access` instruction.
5254
- // Original: y = begin_access x
5255
- // Adjoint: nothing
5245
+ // / Handle `begin_access` instruction.
5246
+ // / Original: y = begin_access x
5247
+ // / Adjoint: nothing (differentiability checks, cleanup propagation)
5256
5248
void visitBeginAccessInst (BeginAccessInst *bai) {
5257
5249
// Check for non-differentiable writes.
5258
5250
if (bai->getAccessKind () == SILAccessKind::Modify) {
@@ -5291,7 +5283,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
5291
5283
#define NOT_DIFFERENTIABLE (INST, DIAG ) \
5292
5284
void visit##INST##Inst(INST##Inst *inst) { \
5293
5285
getContext ().emitNondifferentiabilityError ( \
5294
- inst, getDifferentiationTask (), DIAG); \
5286
+ inst, getInvoker (), DIAG); \
5295
5287
errorOccurred = true ; \
5296
5288
return ; \
5297
5289
}
0 commit comments