@@ -1741,6 +1741,7 @@ static void dumpActivityInfo(SILFunction &fn,
1741
1741
llvm::raw_ostream &s = llvm::dbgs()) {
1742
1742
s << " Activity info for " << fn.getName () << " at " << indices << ' \n ' ;
1743
1743
for (auto &bb : fn) {
1744
+ s << " bb" << bb.getDebugID () << " :\n " ;
1744
1745
for (auto *arg : bb.getArguments ())
1745
1746
dumpActivityInfo (arg, indices, activityInfo, s);
1746
1747
for (auto &inst : bb)
@@ -3192,7 +3193,7 @@ class VJPEmitter final
3192
3193
// Do standard cloning.
3193
3194
if (!hasActiveResults || !hasActiveArguments) {
3194
3195
LLVM_DEBUG (getADDebugStream () << " No active results:\n " << *ai << ' \n ' );
3195
- SILClonerWithScopes ::visitApplyInst (ai);
3196
+ TypeSubstCloner ::visitApplyInst (ai);
3196
3197
return ;
3197
3198
}
3198
3199
@@ -3893,11 +3894,10 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3893
3894
LLVM_DEBUG (getADDebugStream () << " Adding adjoint for " << originalValue);
3894
3895
#ifndef NDEBUG
3895
3896
auto origTy = remapType (originalValue->getType ()).getASTType ();
3896
- auto tanSpace = origTy->getAutoDiffAssociatedTangentSpace (
3897
- LookUpConformanceInModule (getModule ().getSwiftModule ()));
3897
+ auto tangentSpace = getTangentSpace (origTy);
3898
3898
// The adjoint value must be in the tangent space.
3899
- assert (tanSpace && newAdjointValue.getType ().getASTType ()->isEqual (
3900
- tanSpace ->getCanonicalType ()));
3899
+ assert (tangentSpace && newAdjointValue.getType ().getASTType ()->isEqual (
3900
+ tangentSpace ->getCanonicalType ()));
3901
3901
#endif
3902
3902
auto insertion =
3903
3903
valueMap.try_emplace ({origBB, originalValue}, newAdjointValue);
@@ -4662,11 +4662,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4662
4662
}
4663
4663
4664
4664
// / Handle `struct` instruction.
4665
- // / y = struct (x0, x1, x2, ...)
4666
- // / adj[x0] += struct_extract adj[y], #x0
4667
- // / adj[x1] += struct_extract adj[y], #x1
4668
- // / adj[x2] += struct_extract adj[y], #x2
4669
- // / ...
4665
+ // / Original: y = struct (x0, x1, x2, ...)
4666
+ // / Adjoint: adj[x0] += struct_extract adj[y], #x0
4667
+ // / adj[x1] += struct_extract adj[y], #x1
4668
+ // / adj[x2] += struct_extract adj[y], #x2
4669
+ // / ...
4670
4670
void visitStructInst (StructInst *si) {
4671
4671
auto *bb = si->getParent ();
4672
4672
auto loc = si->getLoc ();
@@ -4684,9 +4684,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4684
4684
auto adjStruct = materializeAdjointDirect (std::move (av), loc);
4685
4685
// Find the struct `TangentVector` type.
4686
4686
auto structTy = remapType (si->getType ()).getASTType ();
4687
- auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace (
4688
- LookUpConformanceInModule (getModule ().getSwiftModule ()))
4689
- ->getType ()->getCanonicalType ();
4687
+ auto tangentVectorTy =
4688
+ getTangentSpace (structTy)->getType ()->getCanonicalType ();
4690
4689
assert (!getModule ().Types .getTypeLowering (
4691
4690
tangentVectorTy, ResilienceExpansion::Minimal)
4692
4691
.isAddressOnly ());
@@ -4737,20 +4736,19 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4737
4736
}
4738
4737
}
4739
4738
4739
+ // / Handle `struct_extract` instruction.
4740
+ // / Original: y = struct_extract x, #field
4741
+ // / Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
4742
+ // / ^~~~~~~
4743
+ // / field in tangent space corresponding to #field
4740
4744
void visitStructExtractInst (StructExtractInst *sei) {
4741
4745
assert (!sei->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>() &&
4742
4746
" `struct_extract` with `@noDerivative` field should not be "
4743
4747
" differentiated; activity analysis should not marked as varied" );
4744
- // Compute adjoint as follows:
4745
- // y = struct_extract x, #key
4746
- // adj[x] += struct (0, ..., #key': adj[y], ..., 0)
4747
- // where `#key'` is the field in the tangent space corresponding to
4748
- // `#key`.
4749
4748
auto *bb = sei->getParent ();
4750
4749
auto structTy = remapType (sei->getOperand ()->getType ()).getASTType ();
4751
- auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace (
4752
- LookUpConformanceInModule (getModule ().getSwiftModule ()))
4753
- ->getType ()->getCanonicalType ();
4750
+ auto tangentVectorTy =
4751
+ getTangentSpace (structTy)->getType ()->getCanonicalType ();
4754
4752
assert (!getModule ().Types .getTypeLowering (
4755
4753
tangentVectorTy, ResilienceExpansion::Minimal)
4756
4754
.isAddressOnly ());
@@ -4810,9 +4808,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4810
4808
}
4811
4809
4812
4810
// / Handle `tuple` instruction.
4813
- // / y = tuple (x0, x1, x2, ...)
4814
- // / adj[x0] += tuple_extract adj[y], 0
4815
- // / ...
4811
+ // / Original: y = tuple (x0, x1, x2, ...)
4812
+ // / Adjoint: adj[x0] += tuple_extract adj[y], 0
4813
+ // / ...
4816
4814
void visitTupleInst (TupleInst *ti) {
4817
4815
auto *bb = ti->getParent ();
4818
4816
auto av = getAdjointValue (bb, ti);
@@ -4851,9 +4849,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4851
4849
}
4852
4850
4853
4851
// / Handle `tuple_extract` instruction.
4854
- // / y = tuple_extract x, <n>
4855
- // / |--- n-th element
4856
- // / adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
4852
+ // / Original: y = tuple_extract x, <n>
4853
+ // / Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
4854
+ // / ^~~~~~
4855
+ // / n'-th element, where n' is tuple tangent space
4856
+ // / index corresponding to n
4857
4857
void visitTupleExtractInst (TupleExtractInst *tei) {
4858
4858
auto *bb = tei->getParent ();
4859
4859
auto tupleTanTy = getRemappedTangentType (tei->getOperand ()->getType ());
@@ -5234,9 +5234,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
5234
5234
5235
5235
void AdjointEmitter::emitZeroIndirect (CanType type, SILValue bufferAccess,
5236
5236
SILLocation loc) {
5237
- auto *swiftMod = getModule ().getSwiftModule ();
5238
- auto tangentSpace = type->getAutoDiffAssociatedTangentSpace (
5239
- LookUpConformanceInModule (swiftMod));
5237
+ auto tangentSpace = getTangentSpace (type);
5240
5238
assert (tangentSpace && " No tangent space for this type" );
5241
5239
switch (tangentSpace->getKind ()) {
5242
5240
case VectorSpace::Kind::Vector:
@@ -5371,9 +5369,7 @@ SILValue AdjointEmitter::accumulateDirect(SILValue lhs, SILValue rhs) {
5371
5369
auto adjointTy = lhs->getType ();
5372
5370
auto adjointASTTy = adjointTy.getASTType ();
5373
5371
auto loc = lhs.getLoc ();
5374
- auto *swiftMod = getModule ().getSwiftModule ();
5375
- auto tangentSpace = adjointASTTy->getAutoDiffAssociatedTangentSpace (
5376
- LookUpConformanceInModule (swiftMod));
5372
+ auto tangentSpace = getTangentSpace (adjointASTTy);
5377
5373
assert (tangentSpace && " No tangent space for this type" );
5378
5374
switch (tangentSpace->getKind ()) {
5379
5375
case VectorSpace::Kind::Vector: {
0 commit comments