@@ -719,51 +719,62 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
719
719
}
720
720
721
721
// SWIFT_ENABLE_TENSORFLOW
722
- // [differentiable] attributes only make sense on functions with
723
- // bodies, because [differentiable] attributes declare actual primals
724
- // and adjoints corresponding to the function body.
722
+ // [differentiable] attributes only make sense on functions with bodies,
723
+ // because [differentiable] attributes declare actual associated functions
724
+ // corresponding to the function body.
725
725
if (!AFD->hasBody ())
726
726
return ;
727
727
728
- // If the declaration has a @differentiable(reverse) attribute, turn it into a
729
- // SIL [differentiable] attribute with lowered associated function names and
730
- // lowered differentiation parameter indices.
731
- //
728
+ // Look for a @differentiable attribute on the decl.
732
729
// FIXME: Handle multiple @differentiable attributes.
733
- if (auto *diffAttr = cast_or_null<DifferentiableAttr>(
734
- AFD->getAttrs ().getAttribute (DeclAttrKind::DAK_Differentiable))) {
735
- auto silOriginalFn = getFunction (SILDeclRef (AFD), ForDefinition);
736
- // Either only adjoint is specified, or both primal and adjoint are
737
- // spcified.
738
- StringRef primName, adjName, jvpName, vjpName;
739
- bool hasPrimitiveAdjoint = false ;
740
- if (auto *primFn = diffAttr->getPrimalFunction ())
741
- primName = getFunction (SILDeclRef (primFn), ForDefinition)->getName ();
742
- if (auto *adjointFn = diffAttr->getAdjointFunction ()) {
743
- // If the adjoint is specified but the primal is not, then we treat the
744
- // original as the primal.
745
- if (primName.empty ())
746
- primName = silOriginalFn->getName ();
747
- adjName = getFunction (SILDeclRef (adjointFn), ForDefinition)->getName ();
748
- hasPrimitiveAdjoint = true ;
749
- }
750
- else {
751
- assert (primName.empty () &&
752
- " Primal cannot be present if adjoint is not" );
730
+ DifferentiableAttr *diffAttr = nullptr ;
731
+ if (AFD->getAttrs ().hasAttribute <DifferentiableAttr>())
732
+ diffAttr = AFD->getAttrs ().getAttribute <DifferentiableAttr>();
733
+ // If the AFD is the getter for a storage decl, also look for a
734
+ // @differentiable attribute on the storage decl, because @differentiable
735
+ // attributes on storage decls modify the getter.
736
+ if (auto *accessor = dyn_cast<AccessorDecl>(AFD)) {
737
+ if (accessor->isGetter ()) {
738
+ auto &storageAttrs = accessor->getStorage ()->getAttrs ();
739
+ if (storageAttrs.hasAttribute <DifferentiableAttr>())
740
+ diffAttr = storageAttrs.getAttribute <DifferentiableAttr>();
753
741
}
754
- if (auto *jvpFn = diffAttr->getJVPFunction ())
755
- jvpName = getFunction (SILDeclRef (jvpFn), ForDefinition)->getName ();
756
- if (auto *vjpFn = diffAttr->getVJPFunction ())
757
- vjpName = getFunction (SILDeclRef (vjpFn), ForDefinition)->getName ();
758
- // Get lowered argument indices.
759
- auto paramIndices = diffAttr->getCheckedParameterIndices ()->getLowered (
760
- AFD->getInterfaceType ()->castTo <AnyFunctionType>());
761
- SILAutoDiffIndices indices (/* source*/ 0 , paramIndices);
762
- silOriginalFn->addDifferentiableAttr (
763
- SILDifferentiableAttr::create (
764
- M, indices, primName, adjName,
765
- /* primitive*/ hasPrimitiveAdjoint, jvpName, vjpName));
766
742
}
743
+
744
+ if (!diffAttr)
745
+ return ;
746
+
747
+ // The declaration (or its storage decl) has a @differentiable attribute, so
748
+ // turn it into a SIL [differentiable] attribute with lowered associated
749
+ // function names and lowered differentiation parameter indices.
750
+ auto silOriginalFn = getFunction (SILDeclRef (AFD), ForDefinition);
751
+ // Either only adjoint is specified, or both primal and adjoint are
752
+ // spcified.
753
+ StringRef primName, adjName, jvpName, vjpName;
754
+ bool hasPrimitiveAdjoint = false ;
755
+ if (auto *primFn = diffAttr->getPrimalFunction ())
756
+ primName = getFunction (SILDeclRef (primFn), ForDefinition)->getName ();
757
+ if (auto *adjointFn = diffAttr->getAdjointFunction ()) {
758
+ // If the adjoint is specified but the primal is not, then we treat the
759
+ // original as the primal.
760
+ if (primName.empty ())
761
+ primName = silOriginalFn->getName ();
762
+ adjName = getFunction (SILDeclRef (adjointFn), ForDefinition)->getName ();
763
+ hasPrimitiveAdjoint = true ;
764
+ } else {
765
+ assert (primName.empty () && " Primal cannot be present if adjoint is not" );
766
+ }
767
+ if (auto *jvpFn = diffAttr->getJVPFunction ())
768
+ jvpName = getFunction (SILDeclRef (jvpFn), ForDefinition)->getName ();
769
+ if (auto *vjpFn = diffAttr->getVJPFunction ())
770
+ vjpName = getFunction (SILDeclRef (vjpFn), ForDefinition)->getName ();
771
+ // Get lowered argument indices.
772
+ auto paramIndices = diffAttr->getCheckedParameterIndices ()->getLowered (
773
+ AFD->getInterfaceType ()->castTo <AnyFunctionType>());
774
+ SILAutoDiffIndices indices (/* source*/ 0 , paramIndices);
775
+ silOriginalFn->addDifferentiableAttr (SILDifferentiableAttr::create (
776
+ M, indices, primName, adjName,
777
+ /* primitive*/ hasPrimitiveAdjoint, jvpName, vjpName));
767
778
}
768
779
769
780
void SILGenModule::emitFunction (FuncDecl *fd) {
0 commit comments