@@ -1612,6 +1612,141 @@ class OriginallyDefinedInAttr: public DeclAttribute {
1612
1612
}
1613
1613
};
1614
1614
1615
+ // / A declaration name with location.
1616
+ struct DeclNameWithLoc {
1617
+ DeclName Name;
1618
+ DeclNameLoc Loc;
1619
+ };
1620
+
1621
+ // / Attribute that marks a function as differentiable and optionally specifies
1622
+ // / custom associated derivative functions: 'jvp' and 'vjp'.
1623
+ // /
1624
+ // / Examples:
1625
+ // / @differentiable(jvp: jvpFoo where T : FloatingPoint)
1626
+ // / @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1627
+ class DifferentiableAttr final
1628
+ : public DeclAttribute,
1629
+ private llvm::TrailingObjects<DifferentiableAttr,
1630
+ ParsedAutoDiffParameter> {
1631
+ friend TrailingObjects;
1632
+
1633
+ // / Whether this function is linear (optional).
1634
+ bool Linear;
1635
+ // / The number of parsed parameters specified in 'wrt:'.
1636
+ unsigned NumParsedParameters = 0 ;
1637
+ // / The JVP function.
1638
+ Optional<DeclNameWithLoc> JVP;
1639
+ // / The VJP function.
1640
+ Optional<DeclNameWithLoc> VJP;
1641
+ // / The JVP function (optional), resolved by the type checker if JVP name is
1642
+ // / specified.
1643
+ FuncDecl *JVPFunction = nullptr ;
1644
+ // / The VJP function (optional), resolved by the type checker if VJP name is
1645
+ // / specified.
1646
+ FuncDecl *VJPFunction = nullptr ;
1647
+ // / The differentiation parameters' indices, resolved by the type checker.
1648
+ IndexSubset *ParameterIndices = nullptr ;
1649
+ // / The trailing where clause (optional).
1650
+ TrailingWhereClause *WhereClause = nullptr ;
1651
+ // / The generic signature for autodiff associated functions. Resolved by the
1652
+ // / type checker based on the original function's generic signature and the
1653
+ // / attribute's where clause requirements. This is set only if the attribute
1654
+ // / has a where clause.
1655
+ GenericSignature DerivativeGenericSignature;
1656
+
1657
+ explicit DifferentiableAttr (bool implicit, SourceLoc atLoc,
1658
+ SourceRange baseRange, bool linear,
1659
+ ArrayRef<ParsedAutoDiffParameter> parameters,
1660
+ Optional<DeclNameWithLoc> jvp,
1661
+ Optional<DeclNameWithLoc> vjp,
1662
+ TrailingWhereClause *clause);
1663
+
1664
+ explicit DifferentiableAttr (Decl *original, bool implicit, SourceLoc atLoc,
1665
+ SourceRange baseRange, bool linear,
1666
+ IndexSubset *parameterIndices,
1667
+ Optional<DeclNameWithLoc> jvp,
1668
+ Optional<DeclNameWithLoc> vjp,
1669
+ GenericSignature derivativeGenericSignature);
1670
+
1671
+ public:
1672
+ static DifferentiableAttr *create (ASTContext &context, bool implicit,
1673
+ SourceLoc atLoc, SourceRange baseRange,
1674
+ bool linear,
1675
+ ArrayRef<ParsedAutoDiffParameter> params,
1676
+ Optional<DeclNameWithLoc> jvp,
1677
+ Optional<DeclNameWithLoc> vjp,
1678
+ TrailingWhereClause *clause);
1679
+
1680
+ static DifferentiableAttr *create (AbstractFunctionDecl *original,
1681
+ bool implicit, SourceLoc atLoc,
1682
+ SourceRange baseRange, bool linear,
1683
+ IndexSubset *parameterIndices,
1684
+ Optional<DeclNameWithLoc> jvp,
1685
+ Optional<DeclNameWithLoc> vjp,
1686
+ GenericSignature derivativeGenSig);
1687
+
1688
+ // / Get the optional 'jvp:' function name and location.
1689
+ // / Use this instead of `getJVPFunction` to check whether the attribute has a
1690
+ // / registered JVP.
1691
+ Optional<DeclNameWithLoc> getJVP () const { return JVP; }
1692
+
1693
+ // / Get the optional 'vjp:' function name and location.
1694
+ // / Use this instead of `getVJPFunction` to check whether the attribute has a
1695
+ // / registered VJP.
1696
+ Optional<DeclNameWithLoc> getVJP () const { return VJP; }
1697
+
1698
+ IndexSubset *getParameterIndices () const {
1699
+ return ParameterIndices;
1700
+ }
1701
+ void setParameterIndices (IndexSubset *parameterIndices) {
1702
+ ParameterIndices = parameterIndices;
1703
+ }
1704
+
1705
+ // / The parsed differentiation parameters, i.e. the list of parameters
1706
+ // / specified in 'wrt:'.
1707
+ ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
1708
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1709
+ }
1710
+ MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
1711
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1712
+ }
1713
+ size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
1714
+ return NumParsedParameters;
1715
+ }
1716
+
1717
+ bool isLinear () const { return Linear; }
1718
+
1719
+ TrailingWhereClause *getWhereClause () const { return WhereClause; }
1720
+
1721
+ GenericSignature getDerivativeGenericSignature () const {
1722
+ return DerivativeGenericSignature;
1723
+ }
1724
+ void setDerivativeGenericSignature (GenericSignature derivativeGenSig) {
1725
+ DerivativeGenericSignature = derivativeGenSig;
1726
+ }
1727
+
1728
+ FuncDecl *getJVPFunction () const { return JVPFunction; }
1729
+ void setJVPFunction (FuncDecl *decl);
1730
+ FuncDecl *getVJPFunction () const { return VJPFunction; }
1731
+ void setVJPFunction (FuncDecl *decl);
1732
+
1733
+ // / Get the derivative generic environment for the given `@differentiable`
1734
+ // / attribute and original function.
1735
+ GenericEnvironment *
1736
+ getDerivativeGenericEnvironment (AbstractFunctionDecl *original) const ;
1737
+
1738
+ // Print the attribute to the given stream.
1739
+ // If `omitWrtClause` is true, omit printing the `wrt:` clause.
1740
+ // If `omitAssociatedFunctions` is true, omit printing associated functions.
1741
+ void print (llvm::raw_ostream &OS, const Decl *D,
1742
+ bool omitWrtClause = false ,
1743
+ bool omitAssociatedFunctions = false ) const ;
1744
+
1745
+ static bool classof (const DeclAttribute *DA) {
1746
+ return DA->getKind () == DAK_Differentiable;
1747
+ }
1748
+ };
1749
+
1615
1750
// / Attributes that may be applied to declarations.
1616
1751
class DeclAttributes {
1617
1752
// / Linked list of declaration attributes.
@@ -1791,148 +1926,6 @@ class DeclAttributes {
1791
1926
SourceLoc getStartLoc (bool forModifiers = false ) const ;
1792
1927
};
1793
1928
1794
- // / A declaration name with location.
1795
- struct DeclNameWithLoc {
1796
- DeclName Name;
1797
- DeclNameLoc Loc;
1798
- };
1799
-
1800
- // / Attribute that marks a function as differentiable and optionally specifies
1801
- // / custom associated derivative functions: 'jvp' and 'vjp'.
1802
- // /
1803
- // / Examples:
1804
- // / @differentiable(jvp: jvpFoo where T : FloatingPoint)
1805
- // / @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1806
- class DifferentiableAttr final
1807
- : public DeclAttribute,
1808
- private llvm::TrailingObjects<DifferentiableAttr,
1809
- ParsedAutoDiffParameter> {
1810
- friend TrailingObjects;
1811
-
1812
- // / Whether this function is linear (optional).
1813
- bool linear;
1814
- // / The number of parsed parameters specified in 'wrt:'.
1815
- unsigned NumParsedParameters = 0 ;
1816
- // / The JVP function.
1817
- Optional<DeclNameWithLoc> JVP;
1818
- // / The VJP function.
1819
- Optional<DeclNameWithLoc> VJP;
1820
- // / The JVP function (optional), resolved by the type checker if JVP name is
1821
- // / specified.
1822
- FuncDecl *JVPFunction = nullptr ;
1823
- // / The VJP function (optional), resolved by the type checker if VJP name is
1824
- // / specified.
1825
- FuncDecl *VJPFunction = nullptr ;
1826
- // / The differentiation parameters' indices, resolved by the type checker.
1827
- IndexSubset *ParameterIndices = nullptr ;
1828
- // / The trailing where clause (optional).
1829
- TrailingWhereClause *WhereClause = nullptr ;
1830
- // / The generic signature for autodiff associated functions. Resolved by the
1831
- // / type checker based on the original function's generic signature and the
1832
- // / attribute's where clause requirements. This is set only if the attribute
1833
- // / has a where clause.
1834
- GenericSignature DerivativeGenericSignature;
1835
-
1836
- explicit DifferentiableAttr (ASTContext &context, bool implicit,
1837
- SourceLoc atLoc, SourceRange baseRange,
1838
- bool linear,
1839
- ArrayRef<ParsedAutoDiffParameter> parameters,
1840
- Optional<DeclNameWithLoc> jvp,
1841
- Optional<DeclNameWithLoc> vjp,
1842
- TrailingWhereClause *clause);
1843
-
1844
- explicit DifferentiableAttr (ASTContext &context, bool implicit,
1845
- SourceLoc atLoc, SourceRange baseRange,
1846
- bool linear, IndexSubset *indices,
1847
- Optional<DeclNameWithLoc> jvp,
1848
- Optional<DeclNameWithLoc> vjp,
1849
- GenericSignature derivativeGenericSignature);
1850
-
1851
- public:
1852
- static DifferentiableAttr *create (ASTContext &context, bool implicit,
1853
- SourceLoc atLoc, SourceRange baseRange,
1854
- bool linear,
1855
- ArrayRef<ParsedAutoDiffParameter> params,
1856
- Optional<DeclNameWithLoc> jvp,
1857
- Optional<DeclNameWithLoc> vjp,
1858
- TrailingWhereClause *clause);
1859
-
1860
- static DifferentiableAttr *create (ASTContext &context, bool implicit,
1861
- SourceLoc atLoc, SourceRange baseRange,
1862
- bool linear, IndexSubset *indices,
1863
- Optional<DeclNameWithLoc> jvp,
1864
- Optional<DeclNameWithLoc> vjp,
1865
- GenericSignature derivativeGenSig);
1866
-
1867
- // / Get the optional 'jvp:' function name and location.
1868
- // / Use this instead of `getJVPFunction` to check whether the attribute has a
1869
- // / registered JVP.
1870
- Optional<DeclNameWithLoc> getJVP () const { return JVP; }
1871
-
1872
- // / Get the optional 'vjp:' function name and location.
1873
- // / Use this instead of `getVJPFunction` to check whether the attribute has a
1874
- // / registered VJP.
1875
- Optional<DeclNameWithLoc> getVJP () const { return VJP; }
1876
-
1877
- IndexSubset *getParameterIndices () const {
1878
- return ParameterIndices;
1879
- }
1880
- void setParameterIndices (IndexSubset *pi) {
1881
- ParameterIndices = pi;
1882
- }
1883
-
1884
- // / The parsed differentiation parameters, i.e. the list of parameters
1885
- // / specified in 'wrt:'.
1886
- ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
1887
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1888
- }
1889
- MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
1890
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1891
- }
1892
- size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
1893
- return NumParsedParameters;
1894
- }
1895
-
1896
- bool isLinear () const { return linear; }
1897
-
1898
- TrailingWhereClause *getWhereClause () const { return WhereClause; }
1899
-
1900
- GenericSignature getDerivativeGenericSignature () const {
1901
- return DerivativeGenericSignature;
1902
- }
1903
- void setDerivativeGenericSignature (ASTContext &context,
1904
- GenericSignature derivativeGenSig) {
1905
- DerivativeGenericSignature = derivativeGenSig;
1906
- }
1907
-
1908
- FuncDecl *getJVPFunction () const { return JVPFunction; }
1909
- void setJVPFunction (FuncDecl *decl);
1910
- FuncDecl *getVJPFunction () const { return VJPFunction; }
1911
- void setVJPFunction (FuncDecl *decl);
1912
-
1913
- bool parametersMatch (const DifferentiableAttr &other) const {
1914
- assert (ParameterIndices && other.ParameterIndices );
1915
- return ParameterIndices == other.ParameterIndices ;
1916
- }
1917
-
1918
- // / Get the derivative generic environment for the given `@differentiable`
1919
- // / attribute and original function.
1920
- GenericEnvironment *
1921
- getDerivativeGenericEnvironment (AbstractFunctionDecl *original) const ;
1922
-
1923
- // Print the attribute to the given stream.
1924
- // If `omitWrtClause` is true, omit printing the `wrt:` clause.
1925
- // If `omitAssociatedFunctions` is true, omit printing associated functions.
1926
- void print (llvm::raw_ostream &OS, const Decl *D,
1927
- bool omitWrtClause = false ,
1928
- bool omitAssociatedFunctions = false ) const ;
1929
-
1930
- static bool classof (const DeclAttribute *DA) {
1931
- return DA->getKind () == DAK_Differentiable;
1932
- }
1933
- };
1934
-
1935
-
1936
1929
void simple_display (llvm::raw_ostream &out, const DeclAttribute *attr);
1937
1930
1938
1931
inline SourceLoc extractNearestSourceLoc (const DeclAttribute *attr) {
0 commit comments