Skip to content

Commit 374e64d

Browse files
committed
Upstream @differentiable attribute parsing.
1 parent 937e57f commit 374e64d

File tree

14 files changed

+1509
-1
lines changed

14 files changed

+1509
-1
lines changed

include/swift/AST/Attr.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,12 @@ SIMPLE_DECL_ATTR(_nonEphemeral, NonEphemeral,
502502
ABIStableToAdd | ABIStableToRemove | APIBreakingToAdd | APIStableToRemove,
503503
90)
504504

505+
DECL_ATTR(differentiable, Differentiable,
506+
OnAccessor | OnConstructor | OnFunc | OnVar | OnSubscript | LongAttribute |
507+
AllowMultipleAttributes |
508+
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
509+
90)
510+
505511
SIMPLE_DECL_ATTR(IBSegueAction, IBSegueAction,
506512
OnFunc |
507513
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,

include/swift/AST/Attr.h

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "swift/Basic/Version.h"
2828
#include "swift/AST/Identifier.h"
2929
#include "swift/AST/AttrKind.h"
30+
#include "swift/AST/AutoDiff.h"
3031
#include "swift/AST/ConcreteDeclRef.h"
3132
#include "swift/AST/DeclNameLoc.h"
3233
#include "swift/AST/KnownProtocols.h"
@@ -1724,6 +1725,148 @@ class DeclAttributes {
17241725
SourceLoc getStartLoc(bool forModifiers = false) const;
17251726
};
17261727

1728+
/// A declaration name with location.
1729+
struct DeclNameWithLoc {
1730+
DeclName Name;
1731+
DeclNameLoc Loc;
1732+
};
1733+
1734+
/// Attribute that marks a function as differentiable and optionally specifies
1735+
/// custom associated derivative functions: 'jvp' and 'vjp'.
1736+
///
1737+
/// Examples:
1738+
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1739+
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1740+
class DifferentiableAttr final
1741+
: public DeclAttribute,
1742+
private llvm::TrailingObjects<DifferentiableAttr,
1743+
ParsedAutoDiffParameter> {
1744+
friend TrailingObjects;
1745+
1746+
/// Whether this function is linear (optional).
1747+
bool linear;
1748+
/// The number of parsed parameters specified in 'wrt:'.
1749+
unsigned NumParsedParameters = 0;
1750+
/// The JVP function.
1751+
Optional<DeclNameWithLoc> JVP;
1752+
/// The VJP function.
1753+
Optional<DeclNameWithLoc> VJP;
1754+
/// The JVP function (optional), resolved by the type checker if JVP name is
1755+
/// specified.
1756+
FuncDecl *JVPFunction = nullptr;
1757+
/// The VJP function (optional), resolved by the type checker if VJP name is
1758+
/// specified.
1759+
FuncDecl *VJPFunction = nullptr;
1760+
/// The differentiation parameters' indices, resolved by the type checker.
1761+
AutoDiffIndexSubset *ParameterIndices = nullptr;
1762+
/// The trailing where clause (optional).
1763+
TrailingWhereClause *WhereClause = nullptr;
1764+
/// The generic signature for autodiff associated functions. Resolved by the
1765+
/// type checker based on the original function's generic signature and the
1766+
/// attribute's where clause requirements. This is set only if the attribute
1767+
/// has a where clause.
1768+
GenericSignature DerivativeGenericSignature;
1769+
1770+
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1771+
SourceLoc atLoc, SourceRange baseRange,
1772+
bool linear,
1773+
ArrayRef<ParsedAutoDiffParameter> parameters,
1774+
Optional<DeclNameWithLoc> jvp,
1775+
Optional<DeclNameWithLoc> vjp,
1776+
TrailingWhereClause *clause);
1777+
1778+
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1779+
SourceLoc atLoc, SourceRange baseRange,
1780+
bool linear, AutoDiffIndexSubset *indices,
1781+
Optional<DeclNameWithLoc> jvp,
1782+
Optional<DeclNameWithLoc> vjp,
1783+
GenericSignature derivativeGenericSignature);
1784+
1785+
public:
1786+
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1787+
SourceLoc atLoc, SourceRange baseRange,
1788+
bool linear,
1789+
ArrayRef<ParsedAutoDiffParameter> params,
1790+
Optional<DeclNameWithLoc> jvp,
1791+
Optional<DeclNameWithLoc> vjp,
1792+
TrailingWhereClause *clause);
1793+
1794+
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1795+
SourceLoc atLoc, SourceRange baseRange,
1796+
bool linear, AutoDiffIndexSubset *indices,
1797+
Optional<DeclNameWithLoc> jvp,
1798+
Optional<DeclNameWithLoc> vjp,
1799+
GenericSignature derivativeGenSig);
1800+
1801+
/// Get the optional 'jvp:' function name and location.
1802+
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1803+
/// registered JVP.
1804+
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
1805+
1806+
/// Get the optional 'vjp:' function name and location.
1807+
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1808+
/// registered VJP.
1809+
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
1810+
1811+
AutoDiffIndexSubset *getParameterIndices() const {
1812+
return ParameterIndices;
1813+
}
1814+
void setParameterIndices(AutoDiffIndexSubset *pi) {
1815+
ParameterIndices = pi;
1816+
}
1817+
1818+
/// The parsed differentiation parameters, i.e. the list of parameters
1819+
/// specified in 'wrt:'.
1820+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1821+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1822+
}
1823+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1824+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1825+
}
1826+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1827+
return NumParsedParameters;
1828+
}
1829+
1830+
bool isLinear() const { return linear; }
1831+
1832+
TrailingWhereClause *getWhereClause() const { return WhereClause; }
1833+
1834+
GenericSignature getDerivativeGenericSignature() const {
1835+
return DerivativeGenericSignature;
1836+
}
1837+
void setDerivativeGenericSignature(ASTContext &context,
1838+
GenericSignature derivativeGenSig) {
1839+
DerivativeGenericSignature = derivativeGenSig;
1840+
}
1841+
1842+
FuncDecl *getJVPFunction() const { return JVPFunction; }
1843+
void setJVPFunction(FuncDecl *decl);
1844+
FuncDecl *getVJPFunction() const { return VJPFunction; }
1845+
void setVJPFunction(FuncDecl *decl);
1846+
1847+
bool parametersMatch(const DifferentiableAttr &other) const {
1848+
assert(ParameterIndices && other.ParameterIndices);
1849+
return ParameterIndices == other.ParameterIndices;
1850+
}
1851+
1852+
/// Get the derivative generic environment for the given `@differentiable`
1853+
/// attribute and original function.
1854+
GenericEnvironment *
1855+
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
1856+
1857+
// Print the attribute to the given stream.
1858+
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1859+
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1860+
void print(llvm::raw_ostream &OS, const Decl *D,
1861+
bool omitWrtClause = false,
1862+
bool omitAssociatedFunctions = false) const;
1863+
1864+
static bool classof(const DeclAttribute *DA) {
1865+
return DA->getKind() == DAK_Differentiable;
1866+
}
1867+
};
1868+
1869+
17271870
void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
17281871

17291872
inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {

0 commit comments

Comments
 (0)