@@ -669,6 +669,24 @@ struct NestedApplyActivity {
669
669
SILAutoDiffIndices indices;
670
670
};
671
671
672
+ // / Specifies how we should differentiate a `struct_extract` instruction.
673
+ enum class StructExtractDifferentiationStrategy {
674
+ // The `struct_extract` is not active, so do not differentiate it.
675
+ Inactive,
676
+
677
+ // The `struct_extract` is extracting a field from a Differentiable struct
678
+ // with @_fieldwiseProductSpace cotangent space. Therefore, differentiate the
679
+ // `struct_extract` by setting the adjoint to a vector in the cotangent space
680
+ // that is zero except along the direction of the corresponding field.
681
+ //
682
+ // Fields correspond by matching name.
683
+ FieldwiseProductSpace,
684
+
685
+ // Differentiate the `struct_extract` by looking up the corresponding getter
686
+ // and using its VJP.
687
+ Getter
688
+ };
689
+
672
690
// / A differentiation task, specifying the original function and the
673
691
// / `[differentiable]` attribute on the function. PrimalGen and AdjointGen
674
692
// / will synthesize the primal and the adjoint for this task, filling the primal
@@ -714,6 +732,10 @@ class DifferentiationTask {
714
732
// / Note: This is only used when `DifferentiationUseVJP`.
715
733
DenseMap<ApplyInst *, NestedApplyActivity> nestedApplyActivities;
716
734
735
+ // / Mapping from original `struct_extract` instructions to their strategies.
736
+ DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy>
737
+ structExtractDifferentiationStrategies;
738
+
717
739
// / Cache for associated functions.
718
740
SILFunction *primal = nullptr ;
719
741
SILFunction *adjoint = nullptr ;
@@ -810,6 +832,11 @@ class DifferentiationTask {
810
832
return nestedApplyActivities;
811
833
}
812
834
835
+ DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy> &
836
+ getStructExtractDifferentiationStrategies () {
837
+ return structExtractDifferentiationStrategies;
838
+ }
839
+
813
840
bool isEqual (const DifferentiationTask &other) const {
814
841
return original == other.original && attr == other.attr ;
815
842
}
@@ -2228,16 +2255,42 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2228
2255
}
2229
2256
2230
2257
void visitStructExtractInst (StructExtractInst *sei) {
2258
+ auto &astCtx = getContext ().getASTContext ();
2259
+ auto &structExtractDifferentiationStrategies =
2260
+ getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
2261
+
2231
2262
// Special handling logic only applies when the `struct_extract` is active.
2232
2263
// If not, just do standard cloning.
2233
2264
if (!activityInfo.isActive (sei, synthesis.indices )) {
2234
2265
LLVM_DEBUG (getADDebugStream () << " Not active:\n " << *sei << ' \n ' );
2266
+ structExtractDifferentiationStrategies.insert (
2267
+ {sei, StructExtractDifferentiationStrategy::Inactive});
2235
2268
SILClonerWithScopes::visitStructExtractInst (sei);
2236
2269
return ;
2237
2270
}
2238
2271
2239
- // This instruction is active. Replace it with a call to the corresponding
2240
- // getter's VJP.
2272
+ // This instruction is active. Determine the appropriate differentiation
2273
+ // strategy, and use it.
2274
+
2275
+ // Use the FieldwiseProductSpace strategy, if appropriate.
2276
+ auto *structDecl = sei->getStructDecl ();
2277
+ auto aliasLookup = structDecl->lookupDirect (astCtx.Id_CotangentVector );
2278
+ if (aliasLookup.size () >= 1 ) {
2279
+ assert (aliasLookup.size () == 1 );
2280
+ assert (isa<TypeAliasDecl>(aliasLookup[0 ]));
2281
+ auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0 ]);
2282
+ if (aliasDecl->getAttrs ().hasAttribute <FieldwiseProductSpaceAttr>()) {
2283
+ structExtractDifferentiationStrategies.insert (
2284
+ {sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
2285
+ SILClonerWithScopes::visitStructExtractInst (sei);
2286
+ return ;
2287
+ }
2288
+ }
2289
+
2290
+ // The FieldwiseProductSpace strategy is not appropriate, so use the Getter
2291
+ // strategy.
2292
+ structExtractDifferentiationStrategies.insert (
2293
+ {sei, StructExtractDifferentiationStrategy::Getter});
2241
2294
2242
2295
// Find the corresponding getter and its VJP.
2243
2296
auto *getterDecl = sei->getField ()->getGetter ();
@@ -3596,17 +3649,103 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3596
3649
}
3597
3650
3598
3651
void visitStructExtractInst (StructExtractInst *sei) {
3599
- // Replace a `struct_extract` with a call to its pullback.
3600
3652
auto loc = remapLocation (sei->getLoc ());
3653
+ auto &astCtx = getContext ().getASTContext ();
3601
3654
3602
- // Get the pullback.
3603
- auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3604
- if (!pullbackField) {
3605
- // Inactive `struct_extract` instructions don't need to be cloned into the
3606
- // adjoint.
3655
+ auto &differentiationStrategies =
3656
+ getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
3657
+ auto differentiationStrategyLookUp = differentiationStrategies.find (sei);
3658
+ assert (differentiationStrategyLookUp != differentiationStrategies.end ());
3659
+ auto differentiationStrategy = differentiationStrategyLookUp->second ;
3660
+
3661
+ if (differentiationStrategy ==
3662
+ StructExtractDifferentiationStrategy::Inactive) {
3607
3663
assert (!activityInfo.isActive (sei, synthesis.indices ));
3608
3664
return ;
3609
3665
}
3666
+
3667
+ if (differentiationStrategy ==
3668
+ StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
3669
+ // Compute adjoint as follows:
3670
+ // y = struct_extract <key>, x
3671
+ // adj[x] = struct (0, ..., key': adj[y], ..., 0)
3672
+ // where `key'` is the field in the cotangent space corresponding to
3673
+ // `key`.
3674
+
3675
+ // Find the decl of the cotangent space type.
3676
+ auto *structDecl = sei->getStructDecl ();
3677
+ auto aliasLookup = structDecl->lookupDirect (astCtx.Id_CotangentVector );
3678
+ assert (aliasLookup.size () == 1 );
3679
+ assert (isa<TypeAliasDecl>(aliasLookup[0 ]));
3680
+ auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0 ]);
3681
+ assert (aliasDecl->getAttrs ().hasAttribute <FieldwiseProductSpaceAttr>());
3682
+ auto cotangentVectorTy =
3683
+ aliasDecl->getUnderlyingTypeLoc ().getType ()->getCanonicalType ();
3684
+ assert (!getModule ()
3685
+ .Types .getTypeLowering (cotangentVectorTy)
3686
+ .isAddressOnly ());
3687
+ auto cotangentVectorSILTy =
3688
+ SILType::getPrimitiveObjectType (cotangentVectorTy);
3689
+ auto *cotangentVectorDecl =
3690
+ cotangentVectorTy->getStructOrBoundGenericStruct ();
3691
+ assert (cotangentVectorDecl);
3692
+
3693
+ // Find the corresponding field in the cotangent space.
3694
+ VarDecl *correspondingField = nullptr ;
3695
+ if (cotangentVectorDecl == structDecl)
3696
+ correspondingField = sei->getField ();
3697
+ else {
3698
+ auto correspondingFieldLookup =
3699
+ cotangentVectorDecl->lookupDirect (sei->getField ()->getName ());
3700
+ assert (correspondingFieldLookup.size () == 1 );
3701
+ assert (isa<VarDecl>(correspondingFieldLookup[0 ]));
3702
+ correspondingField = cast<VarDecl>(correspondingFieldLookup[0 ]);
3703
+ }
3704
+ assert (correspondingField);
3705
+
3706
+ #ifndef NDEBUG
3707
+ unsigned numMatchingStoredProperties = 0 ;
3708
+ for (auto *storedProperty : cotangentVectorDecl->getStoredProperties ())
3709
+ if (storedProperty == correspondingField)
3710
+ numMatchingStoredProperties += 1 ;
3711
+ assert (numMatchingStoredProperties == 1 );
3712
+ #endif
3713
+
3714
+ // Compute adjoint.
3715
+ auto av = getAdjointValue (sei);
3716
+ switch (av.getKind ()) {
3717
+ case AdjointValue::Kind::Zero:
3718
+ addAdjointValue (sei->getOperand (),
3719
+ AdjointValue::getZero (cotangentVectorSILTy));
3720
+ break ;
3721
+ case AdjointValue::Kind::Materialized:
3722
+ case AdjointValue::Kind::Aggregate: {
3723
+ SmallVector<AdjointValue, 8 > eltVals;
3724
+ for (auto *field : cotangentVectorDecl->getStoredProperties ()) {
3725
+ if (field == correspondingField)
3726
+ eltVals.push_back (av);
3727
+ else
3728
+ eltVals.push_back (
3729
+ AdjointValue::getZero (SILType::getPrimitiveObjectType (
3730
+ field->getType ()->getCanonicalType ())));
3731
+ }
3732
+ addAdjointValue (sei->getOperand (),
3733
+ AdjointValue::getAggregate (cotangentVectorSILTy,
3734
+ eltVals, allocator));
3735
+ }
3736
+ }
3737
+
3738
+ return ;
3739
+ }
3740
+
3741
+ // The only remaining strategy is the getter strategy.
3742
+ // Replace the `struct_extract` with a call to its pullback.
3743
+ assert (differentiationStrategy ==
3744
+ StructExtractDifferentiationStrategy::Getter);
3745
+
3746
+ // Get the pullback.
3747
+ auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3748
+ assert (pullbackField);
3610
3749
SILValue pullback = builder.createStructExtract (loc,
3611
3750
primalValueAggregateInAdj,
3612
3751
pullbackField);
0 commit comments