|
27 | 27 | #include "swift/Basic/Version.h"
|
28 | 28 | #include "swift/AST/Identifier.h"
|
29 | 29 | #include "swift/AST/AttrKind.h"
|
| 30 | +#include "swift/AST/AutoDiff.h" |
30 | 31 | #include "swift/AST/ConcreteDeclRef.h"
|
31 | 32 | #include "swift/AST/DeclNameLoc.h"
|
32 | 33 | #include "swift/AST/KnownProtocols.h"
|
@@ -1724,6 +1725,148 @@ class DeclAttributes {
|
1724 | 1725 | SourceLoc getStartLoc(bool forModifiers = false) const;
|
1725 | 1726 | };
|
1726 | 1727 |
|
| 1728 | +/// A declaration name with location. |
| 1729 | +struct DeclNameWithLoc { |
| 1730 | + DeclName Name; |
| 1731 | + DeclNameLoc Loc; |
| 1732 | +}; |
| 1733 | + |
| 1734 | +/// Attribute that marks a function as differentiable and optionally specifies |
| 1735 | +/// custom associated derivative functions: 'jvp' and 'vjp'. |
| 1736 | +/// |
| 1737 | +/// Examples: |
| 1738 | +/// @differentiable(jvp: jvpFoo where T : FloatingPoint) |
| 1739 | +/// @differentiable(wrt: (self, x, y), jvp: jvpFoo) |
| 1740 | +class DifferentiableAttr final |
| 1741 | + : public DeclAttribute, |
| 1742 | + private llvm::TrailingObjects<DifferentiableAttr, |
| 1743 | + ParsedAutoDiffParameter> { |
| 1744 | + friend TrailingObjects; |
| 1745 | + |
| 1746 | + /// Whether this function is linear (optional). |
| 1747 | + bool linear; |
| 1748 | + /// The number of parsed parameters specified in 'wrt:'. |
| 1749 | + unsigned NumParsedParameters = 0; |
| 1750 | + /// The JVP function. |
| 1751 | + Optional<DeclNameWithLoc> JVP; |
| 1752 | + /// The VJP function. |
| 1753 | + Optional<DeclNameWithLoc> VJP; |
| 1754 | + /// The JVP function (optional), resolved by the type checker if JVP name is |
| 1755 | + /// specified. |
| 1756 | + FuncDecl *JVPFunction = nullptr; |
| 1757 | + /// The VJP function (optional), resolved by the type checker if VJP name is |
| 1758 | + /// specified. |
| 1759 | + FuncDecl *VJPFunction = nullptr; |
| 1760 | + /// The differentiation parameters' indices, resolved by the type checker. |
| 1761 | + AutoDiffIndexSubset *ParameterIndices = nullptr; |
| 1762 | + /// The trailing where clause (optional). |
| 1763 | + TrailingWhereClause *WhereClause = nullptr; |
| 1764 | + /// The generic signature for autodiff associated functions. Resolved by the |
| 1765 | + /// type checker based on the original function's generic signature and the |
| 1766 | + /// attribute's where clause requirements. This is set only if the attribute |
| 1767 | + /// has a where clause. |
| 1768 | + GenericSignature DerivativeGenericSignature; |
| 1769 | + |
| 1770 | + explicit DifferentiableAttr(ASTContext &context, bool implicit, |
| 1771 | + SourceLoc atLoc, SourceRange baseRange, |
| 1772 | + bool linear, |
| 1773 | + ArrayRef<ParsedAutoDiffParameter> parameters, |
| 1774 | + Optional<DeclNameWithLoc> jvp, |
| 1775 | + Optional<DeclNameWithLoc> vjp, |
| 1776 | + TrailingWhereClause *clause); |
| 1777 | + |
| 1778 | + explicit DifferentiableAttr(ASTContext &context, bool implicit, |
| 1779 | + SourceLoc atLoc, SourceRange baseRange, |
| 1780 | + bool linear, AutoDiffIndexSubset *indices, |
| 1781 | + Optional<DeclNameWithLoc> jvp, |
| 1782 | + Optional<DeclNameWithLoc> vjp, |
| 1783 | + GenericSignature derivativeGenericSignature); |
| 1784 | + |
| 1785 | +public: |
| 1786 | + static DifferentiableAttr *create(ASTContext &context, bool implicit, |
| 1787 | + SourceLoc atLoc, SourceRange baseRange, |
| 1788 | + bool linear, |
| 1789 | + ArrayRef<ParsedAutoDiffParameter> params, |
| 1790 | + Optional<DeclNameWithLoc> jvp, |
| 1791 | + Optional<DeclNameWithLoc> vjp, |
| 1792 | + TrailingWhereClause *clause); |
| 1793 | + |
| 1794 | + static DifferentiableAttr *create(ASTContext &context, bool implicit, |
| 1795 | + SourceLoc atLoc, SourceRange baseRange, |
| 1796 | + bool linear, AutoDiffIndexSubset *indices, |
| 1797 | + Optional<DeclNameWithLoc> jvp, |
| 1798 | + Optional<DeclNameWithLoc> vjp, |
| 1799 | + GenericSignature derivativeGenSig); |
| 1800 | + |
| 1801 | + /// Get the optional 'jvp:' function name and location. |
| 1802 | + /// Use this instead of `getJVPFunction` to check whether the attribute has a |
| 1803 | + /// registered JVP. |
| 1804 | + Optional<DeclNameWithLoc> getJVP() const { return JVP; } |
| 1805 | + |
| 1806 | + /// Get the optional 'vjp:' function name and location. |
| 1807 | + /// Use this instead of `getVJPFunction` to check whether the attribute has a |
| 1808 | + /// registered VJP. |
| 1809 | + Optional<DeclNameWithLoc> getVJP() const { return VJP; } |
| 1810 | + |
| 1811 | + AutoDiffIndexSubset *getParameterIndices() const { |
| 1812 | + return ParameterIndices; |
| 1813 | + } |
| 1814 | + void setParameterIndices(AutoDiffIndexSubset *pi) { |
| 1815 | + ParameterIndices = pi; |
| 1816 | + } |
| 1817 | + |
| 1818 | + /// The parsed differentiation parameters, i.e. the list of parameters |
| 1819 | + /// specified in 'wrt:'. |
| 1820 | + ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const { |
| 1821 | + return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters}; |
| 1822 | + } |
| 1823 | + MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() { |
| 1824 | + return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters}; |
| 1825 | + } |
| 1826 | + size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const { |
| 1827 | + return NumParsedParameters; |
| 1828 | + } |
| 1829 | + |
| 1830 | + bool isLinear() const { return linear; } |
| 1831 | + |
| 1832 | + TrailingWhereClause *getWhereClause() const { return WhereClause; } |
| 1833 | + |
| 1834 | + GenericSignature getDerivativeGenericSignature() const { |
| 1835 | + return DerivativeGenericSignature; |
| 1836 | + } |
| 1837 | + void setDerivativeGenericSignature(ASTContext &context, |
| 1838 | + GenericSignature derivativeGenSig) { |
| 1839 | + DerivativeGenericSignature = derivativeGenSig; |
| 1840 | + } |
| 1841 | + |
| 1842 | + FuncDecl *getJVPFunction() const { return JVPFunction; } |
| 1843 | + void setJVPFunction(FuncDecl *decl); |
| 1844 | + FuncDecl *getVJPFunction() const { return VJPFunction; } |
| 1845 | + void setVJPFunction(FuncDecl *decl); |
| 1846 | + |
| 1847 | + bool parametersMatch(const DifferentiableAttr &other) const { |
| 1848 | + assert(ParameterIndices && other.ParameterIndices); |
| 1849 | + return ParameterIndices == other.ParameterIndices; |
| 1850 | + } |
| 1851 | + |
| 1852 | + /// Get the derivative generic environment for the given `@differentiable` |
| 1853 | + /// attribute and original function. |
| 1854 | + GenericEnvironment * |
| 1855 | + getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const; |
| 1856 | + |
| 1857 | + // Print the attribute to the given stream. |
| 1858 | + // If `omitWrtClause` is true, omit printing the `wrt:` clause. |
| 1859 | + // If `omitAssociatedFunctions` is true, omit printing associated functions. |
| 1860 | + void print(llvm::raw_ostream &OS, const Decl *D, |
| 1861 | + bool omitWrtClause = false, |
| 1862 | + bool omitAssociatedFunctions = false) const; |
| 1863 | + |
| 1864 | + static bool classof(const DeclAttribute *DA) { |
| 1865 | + return DA->getKind() == DAK_Differentiable; |
| 1866 | + } |
| 1867 | +}; |
| 1868 | + |
| 1869 | + |
1727 | 1870 | void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
|
1728 | 1871 |
|
1729 | 1872 | inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {
|
|
0 commit comments