@@ -252,6 +252,21 @@ static LoadOwnershipQualifier getBufferLOQ(Type type, SILFunction &fn) {
252
252
return LoadOwnershipQualifier::Unqualified;
253
253
}
254
254
255
+ // / Assuming the given type conforms to `Differentiable`, returns the associated
256
+ // / cotangent space type.
257
+ static SILType getCotangentType (CanType type, SILModule &mod) {
258
+ return SILType::getPrimitiveObjectType (
259
+ type->getAutoDiffAssociatedVectorSpace (
260
+ AutoDiffAssociatedVectorSpaceKind::Cotangent,
261
+ LookUpConformanceInModule (mod.getSwiftModule ()))->getCanonicalType ());
262
+ }
263
+
264
+ // / Assuming the given type conforms to `Differentiable`, returns the associated
265
+ // / cotangent space type.
266
+ static SILType getCotangentType (SILType type, SILModule &mod) {
267
+ return getCotangentType (type.getASTType (), mod);
268
+ }
269
+
255
270
// ===----------------------------------------------------------------------===//
256
271
// Auxiliary data structures
257
272
// ===----------------------------------------------------------------------===//
@@ -2891,7 +2906,6 @@ class AdjointValue {
2891
2906
private:
2892
2907
static bool isLegalAggregate (ArrayRef<AdjointValue> elements, SILType type) {
2893
2908
if (auto *structDecl = type.getASTType ()->getStructOrBoundGenericStruct ()) {
2894
- // TODO: Check whether this struct is @_fixed_layout and ABI public.
2895
2909
for (auto pair : llvm::zip (structDecl->getStoredProperties (), elements))
2896
2910
if (!std::get<0 >(pair)->getType ()->getCanonicalType ()
2897
2911
->isEqual (std::get<1 >(pair).getSwiftType ()))
@@ -3130,7 +3144,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3130
3144
AdjointValue getAdjointValue (SILValue originalValue) {
3131
3145
assert (originalValue->getFunction () == &getOriginal ());
3132
3146
auto insertion = adjointMap.try_emplace (
3133
- originalValue, AdjointValue::getZero (originalValue->getType ()));
3147
+ originalValue, AdjointValue::getZero (
3148
+ getCotangentType (originalValue->getType (), getModule ())));
3134
3149
return insertion.first ->getSecond ();
3135
3150
}
3136
3151
@@ -3140,6 +3155,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3140
3155
AdjointValue adjointValue) {
3141
3156
assert (originalValue->getFunction () == &getOriginal ());
3142
3157
LLVM_DEBUG (getADDebugStream () << " Adding adjoint for " << originalValue);
3158
+ #ifndef NDEBUG
3159
+ auto origTy = originalValue->getType ().getASTType ();
3160
+ auto cotanSpace = origTy->getAutoDiffAssociatedVectorSpace (
3161
+ AutoDiffAssociatedVectorSpaceKind::Cotangent,
3162
+ LookUpConformanceInModule (getModule ().getSwiftModule ()));
3163
+ // The adjoint value must be in the cotangent space.
3164
+ assert (cotanSpace && adjointValue.getType ().getASTType ()
3165
+ == cotanSpace->getCanonicalType ());
3166
+ #endif
3143
3167
auto insertion = adjointMap.try_emplace (originalValue, adjointValue);
3144
3168
auto inserted = insertion.second ;
3145
3169
auto &value = insertion.first ->getSecond ();
@@ -3644,25 +3668,32 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3644
3668
void visitStructInst (StructInst *si) {
3645
3669
auto *decl = si->getStructDecl ();
3646
3670
auto av = getAdjointValue (si);
3647
- auto loc = si->getLoc ();
3648
3671
switch (av.getKind ()) {
3649
3672
case AdjointValue::Zero:
3650
3673
for (auto *field : decl->getStoredProperties ()) {
3651
3674
auto fv = si->getFieldValue (field);
3652
- addAdjointValue (fv, AdjointValue::getZero (fv->getType ()));
3675
+ addAdjointValue (
3676
+ fv, AdjointValue::getZero (getCotangentType (fv->getType (),
3677
+ getModule ())));
3653
3678
}
3654
3679
break ;
3655
3680
case AdjointValue::Materialized: {
3656
- auto adjY = av.getMaterializedValue ();
3657
- for (auto *field : decl->getStoredProperties ())
3658
- addAdjointValue (si->getFieldValue (field),
3659
- builder.createStructExtract (loc, adjY, field));
3660
- break ;
3681
+ // FIXME(SR-9602): If `CotangentVector` is not marked
3682
+ // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
3683
+ // auto adjY = av.getMaterializedValue();
3684
+ // for (auto *field : decl->getStoredProperties())
3685
+ // addAdjointValue(si->getFieldValue(field),
3686
+ // builder.createStructExtract(loc, adjY, field));
3687
+ llvm_unreachable (" Unhandled. Are you trying to differentiate a "
3688
+ " memberwise initializer?" );
3661
3689
}
3662
3690
case AdjointValue::Aggregate: {
3663
- for (auto pair : llvm::zip (si->getElements (), av.getAggregateElements ()))
3664
- addAdjointValue (std::get<0 >(pair), std::get<1 >(pair));
3665
- break ;
3691
+ // FIXME(SR-9602): If `CotangentVector` is not marked
3692
+ // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
3693
+ // for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements()))
3694
+ // addAdjointValue(std::get<0>(pair), std::get<1>(pair));
3695
+ llvm_unreachable (" Unhandled. Are you trying to differentiate a "
3696
+ " memberwise initializer?" );
3666
3697
}
3667
3698
}
3668
3699
}
@@ -3739,9 +3770,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3739
3770
if (field == correspondingField)
3740
3771
eltVals.push_back (av);
3741
3772
else
3742
- eltVals.push_back (
3743
- AdjointValue::getZero ( SILType::getPrimitiveObjectType (
3744
- field-> getType ()-> getCanonicalType ())));
3773
+ eltVals.push_back (AdjointValue::getZero (
3774
+ getCotangentType (field-> getType ()-> getCanonicalType (),
3775
+ getModule ())));
3745
3776
}
3746
3777
addAdjointValue (sei->getOperand (),
3747
3778
AdjointValue::getAggregate (cotangentVectorSILTy,
@@ -3789,12 +3820,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3789
3820
switch (av.getKind ()) {
3790
3821
case AdjointValue::Kind::Zero:
3791
3822
for (auto eltVal : ti->getElements ())
3792
- addAdjointValue (eltVal, AdjointValue::getZero (eltVal->getType ()));
3823
+ addAdjointValue (eltVal,
3824
+ AdjointValue::getZero (getCotangentType (eltVal->getType (),
3825
+ getModule ())));
3793
3826
break ;
3794
3827
case AdjointValue::Kind::Materialized:
3795
3828
for (auto i : range (ti->getNumOperands ()))
3796
3829
addAdjointValue (ti->getOperand (i),
3797
- builder.createTupleExtract (ti->getLoc (), ti, i));
3830
+ builder.createTupleExtract (
3831
+ ti->getLoc (), av.getMaterializedValue (), i));
3798
3832
break ;
3799
3833
case AdjointValue::Kind::Aggregate:
3800
3834
for (auto pair : llvm::zip (ti->getElements (), av.getAggregateElements ()))
@@ -3809,28 +3843,26 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3809
3843
// / adj[x] = tuple (0, 0, ..., adj[y], ..., 0, 0)
3810
3844
void visitTupleExtractInst (TupleExtractInst *tei) {
3811
3845
auto *tupleTy = tei->getTupleType ();
3846
+ auto tupleCotanTy = getCotangentType (tupleTy->getCanonicalType (),
3847
+ getModule ());
3812
3848
auto av = getAdjointValue (tei);
3813
3849
switch (av.getKind ()) {
3814
3850
case AdjointValue::Kind::Zero:
3815
- addAdjointValue (tei->getOperand (),
3816
- AdjointValue::getZero (SILType::getPrimitiveObjectType (
3817
- tupleTy->getCanonicalType ())));
3851
+ addAdjointValue (tei->getOperand (), AdjointValue::getZero (tupleCotanTy));
3818
3852
break ;
3819
3853
case AdjointValue::Kind::Aggregate:
3820
3854
case AdjointValue::Kind::Materialized: {
3821
3855
SmallVector<AdjointValue, 8 > elements;
3822
3856
for (unsigned i : range (tupleTy->getNumElements ())) {
3823
3857
if (tei->getFieldNo () == i)
3824
3858
elements.push_back (av);
3825
- else {
3826
- auto eltTy = SILType::getPrimitiveObjectType (
3827
- tupleTy->getElementType (i)->getCanonicalType ());
3828
- elements.push_back (AdjointValue::getZero (eltTy));
3829
- }
3859
+ else
3860
+ elements.push_back (AdjointValue::getZero (
3861
+ getCotangentType (tupleTy->getElementType (i)->getCanonicalType (),
3862
+ getModule ())));
3830
3863
}
3831
3864
addAdjointValue (tei->getOperand (),
3832
- AdjointValue::getAggregate (tei->getOperand ()->getType (),
3833
- elements, allocator));
3865
+ AdjointValue::getAggregate (tupleCotanTy, elements, allocator));
3834
3866
break ;
3835
3867
}
3836
3868
}
@@ -4406,7 +4438,7 @@ void DifferentiationTask::createEmptyPrimal() {
4406
4438
auto linkage = SILLinkage::Hidden;
4407
4439
primal = fb.getOrCreateFunction (
4408
4440
original->getLocation (), primalName, linkage, primalTy,
4409
- original->isBare (), original-> isTransparent () , original->isSerialized ());
4441
+ original->isBare (), IsNotTransparent , original->isSerialized ());
4410
4442
primal->setUnqualifiedOwnership ();
4411
4443
LLVM_DEBUG (getADDebugStream () << " Primal function created \n "
4412
4444
<< *primal << ' \n ' );
@@ -4531,7 +4563,7 @@ void DifferentiationTask::createEmptyAdjoint() {
4531
4563
auto linkage = SILLinkage::Hidden;
4532
4564
adjoint = fb.createFunction (
4533
4565
linkage, adjName, adjType, original->getGenericEnvironment (),
4534
- original->getLocation (), original->isBare (), original-> isTransparent () ,
4566
+ original->getLocation (), original->isBare (), IsNotTransparent ,
4535
4567
original->isSerialized ());
4536
4568
adjoint->setUnqualifiedOwnership ();
4537
4569
adjoint->setDebugScope (new (module )
@@ -4556,7 +4588,7 @@ void DifferentiationTask::createJVP() {
4556
4588
jvp = fb.createFunction (original->getLinkage (), jvpName, jvpType,
4557
4589
original->getGenericEnvironment (),
4558
4590
original->getLocation (), original->isBare (),
4559
- original-> isTransparent () , original->isSerialized ());
4591
+ IsNotTransparent , original->isSerialized ());
4560
4592
jvp->setUnqualifiedOwnership ();
4561
4593
jvp->setDebugScope (new (module ) SILDebugScope (original->getLocation (), jvp));
4562
4594
attr->setJVPName (jvp->getName ());
@@ -4613,7 +4645,7 @@ void DifferentiationTask::createVJP() {
4613
4645
vjp = fb.createFunction (linkage, vjpName, vjpType,
4614
4646
original->getGenericEnvironment (),
4615
4647
original->getLocation (), original->isBare (),
4616
- original-> isTransparent () , original->isSerialized ());
4648
+ IsNotTransparent , original->isSerialized ());
4617
4649
vjp->setUnqualifiedOwnership ();
4618
4650
vjp->setDebugScope (new (module )
4619
4651
SILDebugScope (original->getLocation (), vjp));
0 commit comments