Skip to content

Commit 6742368

Browse files
authored
[AutoDiff] NFC: @differentiable attribute gardening. (#28666)
- Move `DifferentiableAttr` definition above `DeclAttributes` in include/swift/AST/Attr.h, like other attributes. - Remove unnecessary arguments from `DifferentiableAttr::DifferentiableAttr` and `DifferentiableAttr::setDerivativeGenericSignature`. - Add libSyntax test for `@differentiable` attributes.
1 parent 113baa6 commit 6742368

File tree

4 files changed

+207
-160
lines changed

4 files changed

+207
-160
lines changed

include/swift/AST/Attr.h

Lines changed: 135 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,141 @@ class OriginallyDefinedInAttr: public DeclAttribute {
16121612
}
16131613
};
16141614

1615+
/// A declaration name with location.
1616+
struct DeclNameWithLoc {
1617+
DeclName Name;
1618+
DeclNameLoc Loc;
1619+
};
1620+
1621+
/// Attribute that marks a function as differentiable and optionally specifies
1622+
/// custom associated derivative functions: 'jvp' and 'vjp'.
1623+
///
1624+
/// Examples:
1625+
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1626+
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1627+
class DifferentiableAttr final
1628+
: public DeclAttribute,
1629+
private llvm::TrailingObjects<DifferentiableAttr,
1630+
ParsedAutoDiffParameter> {
1631+
friend TrailingObjects;
1632+
1633+
/// Whether this function is linear (optional).
1634+
bool Linear;
1635+
/// The number of parsed parameters specified in 'wrt:'.
1636+
unsigned NumParsedParameters = 0;
1637+
/// The JVP function.
1638+
Optional<DeclNameWithLoc> JVP;
1639+
/// The VJP function.
1640+
Optional<DeclNameWithLoc> VJP;
1641+
/// The JVP function (optional), resolved by the type checker if JVP name is
1642+
/// specified.
1643+
FuncDecl *JVPFunction = nullptr;
1644+
/// The VJP function (optional), resolved by the type checker if VJP name is
1645+
/// specified.
1646+
FuncDecl *VJPFunction = nullptr;
1647+
/// The differentiation parameters' indices, resolved by the type checker.
1648+
IndexSubset *ParameterIndices = nullptr;
1649+
/// The trailing where clause (optional).
1650+
TrailingWhereClause *WhereClause = nullptr;
1651+
/// The generic signature for autodiff associated functions. Resolved by the
1652+
/// type checker based on the original function's generic signature and the
1653+
/// attribute's where clause requirements. This is set only if the attribute
1654+
/// has a where clause.
1655+
GenericSignature DerivativeGenericSignature;
1656+
1657+
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
1658+
SourceRange baseRange, bool linear,
1659+
ArrayRef<ParsedAutoDiffParameter> parameters,
1660+
Optional<DeclNameWithLoc> jvp,
1661+
Optional<DeclNameWithLoc> vjp,
1662+
TrailingWhereClause *clause);
1663+
1664+
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
1665+
SourceRange baseRange, bool linear,
1666+
IndexSubset *parameterIndices,
1667+
Optional<DeclNameWithLoc> jvp,
1668+
Optional<DeclNameWithLoc> vjp,
1669+
GenericSignature derivativeGenericSignature);
1670+
1671+
public:
1672+
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1673+
SourceLoc atLoc, SourceRange baseRange,
1674+
bool linear,
1675+
ArrayRef<ParsedAutoDiffParameter> params,
1676+
Optional<DeclNameWithLoc> jvp,
1677+
Optional<DeclNameWithLoc> vjp,
1678+
TrailingWhereClause *clause);
1679+
1680+
static DifferentiableAttr *create(AbstractFunctionDecl *original,
1681+
bool implicit, SourceLoc atLoc,
1682+
SourceRange baseRange, bool linear,
1683+
IndexSubset *parameterIndices,
1684+
Optional<DeclNameWithLoc> jvp,
1685+
Optional<DeclNameWithLoc> vjp,
1686+
GenericSignature derivativeGenSig);
1687+
1688+
/// Get the optional 'jvp:' function name and location.
1689+
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1690+
/// registered JVP.
1691+
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
1692+
1693+
/// Get the optional 'vjp:' function name and location.
1694+
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1695+
/// registered VJP.
1696+
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
1697+
1698+
IndexSubset *getParameterIndices() const {
1699+
return ParameterIndices;
1700+
}
1701+
void setParameterIndices(IndexSubset *parameterIndices) {
1702+
ParameterIndices = parameterIndices;
1703+
}
1704+
1705+
/// The parsed differentiation parameters, i.e. the list of parameters
1706+
/// specified in 'wrt:'.
1707+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1708+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1709+
}
1710+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1711+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1712+
}
1713+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1714+
return NumParsedParameters;
1715+
}
1716+
1717+
bool isLinear() const { return Linear; }
1718+
1719+
TrailingWhereClause *getWhereClause() const { return WhereClause; }
1720+
1721+
GenericSignature getDerivativeGenericSignature() const {
1722+
return DerivativeGenericSignature;
1723+
}
1724+
void setDerivativeGenericSignature(GenericSignature derivativeGenSig) {
1725+
DerivativeGenericSignature = derivativeGenSig;
1726+
}
1727+
1728+
FuncDecl *getJVPFunction() const { return JVPFunction; }
1729+
void setJVPFunction(FuncDecl *decl);
1730+
FuncDecl *getVJPFunction() const { return VJPFunction; }
1731+
void setVJPFunction(FuncDecl *decl);
1732+
1733+
/// Get the derivative generic environment for the given `@differentiable`
1734+
/// attribute and original function.
1735+
GenericEnvironment *
1736+
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
1737+
1738+
// Print the attribute to the given stream.
1739+
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1740+
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1741+
void print(llvm::raw_ostream &OS, const Decl *D,
1742+
bool omitWrtClause = false,
1743+
bool omitAssociatedFunctions = false) const;
1744+
1745+
static bool classof(const DeclAttribute *DA) {
1746+
return DA->getKind() == DAK_Differentiable;
1747+
}
1748+
};
1749+
16151750
/// Attributes that may be applied to declarations.
16161751
class DeclAttributes {
16171752
/// Linked list of declaration attributes.
@@ -1791,148 +1926,6 @@ class DeclAttributes {
17911926
SourceLoc getStartLoc(bool forModifiers = false) const;
17921927
};
17931928

1794-
/// A declaration name with location.
1795-
struct DeclNameWithLoc {
1796-
DeclName Name;
1797-
DeclNameLoc Loc;
1798-
};
1799-
1800-
/// Attribute that marks a function as differentiable and optionally specifies
1801-
/// custom associated derivative functions: 'jvp' and 'vjp'.
1802-
///
1803-
/// Examples:
1804-
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1805-
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1806-
class DifferentiableAttr final
1807-
: public DeclAttribute,
1808-
private llvm::TrailingObjects<DifferentiableAttr,
1809-
ParsedAutoDiffParameter> {
1810-
friend TrailingObjects;
1811-
1812-
/// Whether this function is linear (optional).
1813-
bool linear;
1814-
/// The number of parsed parameters specified in 'wrt:'.
1815-
unsigned NumParsedParameters = 0;
1816-
/// The JVP function.
1817-
Optional<DeclNameWithLoc> JVP;
1818-
/// The VJP function.
1819-
Optional<DeclNameWithLoc> VJP;
1820-
/// The JVP function (optional), resolved by the type checker if JVP name is
1821-
/// specified.
1822-
FuncDecl *JVPFunction = nullptr;
1823-
/// The VJP function (optional), resolved by the type checker if VJP name is
1824-
/// specified.
1825-
FuncDecl *VJPFunction = nullptr;
1826-
/// The differentiation parameters' indices, resolved by the type checker.
1827-
IndexSubset *ParameterIndices = nullptr;
1828-
/// The trailing where clause (optional).
1829-
TrailingWhereClause *WhereClause = nullptr;
1830-
/// The generic signature for autodiff associated functions. Resolved by the
1831-
/// type checker based on the original function's generic signature and the
1832-
/// attribute's where clause requirements. This is set only if the attribute
1833-
/// has a where clause.
1834-
GenericSignature DerivativeGenericSignature;
1835-
1836-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1837-
SourceLoc atLoc, SourceRange baseRange,
1838-
bool linear,
1839-
ArrayRef<ParsedAutoDiffParameter> parameters,
1840-
Optional<DeclNameWithLoc> jvp,
1841-
Optional<DeclNameWithLoc> vjp,
1842-
TrailingWhereClause *clause);
1843-
1844-
explicit DifferentiableAttr(ASTContext &context, bool implicit,
1845-
SourceLoc atLoc, SourceRange baseRange,
1846-
bool linear, IndexSubset *indices,
1847-
Optional<DeclNameWithLoc> jvp,
1848-
Optional<DeclNameWithLoc> vjp,
1849-
GenericSignature derivativeGenericSignature);
1850-
1851-
public:
1852-
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1853-
SourceLoc atLoc, SourceRange baseRange,
1854-
bool linear,
1855-
ArrayRef<ParsedAutoDiffParameter> params,
1856-
Optional<DeclNameWithLoc> jvp,
1857-
Optional<DeclNameWithLoc> vjp,
1858-
TrailingWhereClause *clause);
1859-
1860-
static DifferentiableAttr *create(ASTContext &context, bool implicit,
1861-
SourceLoc atLoc, SourceRange baseRange,
1862-
bool linear, IndexSubset *indices,
1863-
Optional<DeclNameWithLoc> jvp,
1864-
Optional<DeclNameWithLoc> vjp,
1865-
GenericSignature derivativeGenSig);
1866-
1867-
/// Get the optional 'jvp:' function name and location.
1868-
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1869-
/// registered JVP.
1870-
Optional<DeclNameWithLoc> getJVP() const { return JVP; }
1871-
1872-
/// Get the optional 'vjp:' function name and location.
1873-
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1874-
/// registered VJP.
1875-
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
1876-
1877-
IndexSubset *getParameterIndices() const {
1878-
return ParameterIndices;
1879-
}
1880-
void setParameterIndices(IndexSubset *pi) {
1881-
ParameterIndices = pi;
1882-
}
1883-
1884-
/// The parsed differentiation parameters, i.e. the list of parameters
1885-
/// specified in 'wrt:'.
1886-
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1887-
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1888-
}
1889-
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1890-
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1891-
}
1892-
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1893-
return NumParsedParameters;
1894-
}
1895-
1896-
bool isLinear() const { return linear; }
1897-
1898-
TrailingWhereClause *getWhereClause() const { return WhereClause; }
1899-
1900-
GenericSignature getDerivativeGenericSignature() const {
1901-
return DerivativeGenericSignature;
1902-
}
1903-
void setDerivativeGenericSignature(ASTContext &context,
1904-
GenericSignature derivativeGenSig) {
1905-
DerivativeGenericSignature = derivativeGenSig;
1906-
}
1907-
1908-
FuncDecl *getJVPFunction() const { return JVPFunction; }
1909-
void setJVPFunction(FuncDecl *decl);
1910-
FuncDecl *getVJPFunction() const { return VJPFunction; }
1911-
void setVJPFunction(FuncDecl *decl);
1912-
1913-
bool parametersMatch(const DifferentiableAttr &other) const {
1914-
assert(ParameterIndices && other.ParameterIndices);
1915-
return ParameterIndices == other.ParameterIndices;
1916-
}
1917-
1918-
/// Get the derivative generic environment for the given `@differentiable`
1919-
/// attribute and original function.
1920-
GenericEnvironment *
1921-
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
1922-
1923-
// Print the attribute to the given stream.
1924-
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1925-
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1926-
void print(llvm::raw_ostream &OS, const Decl *D,
1927-
bool omitWrtClause = false,
1928-
bool omitAssociatedFunctions = false) const;
1929-
1930-
static bool classof(const DeclAttribute *DA) {
1931-
return DA->getKind() == DAK_Differentiable;
1932-
}
1933-
};
1934-
1935-
19361929
void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
19371930

19381931
inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {

lib/AST/Attr.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,31 +1365,30 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
13651365
specializedSignature);
13661366
}
13671367

1368-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1369-
SourceLoc atLoc, SourceRange baseRange,
1370-
bool linear,
1368+
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
1369+
SourceRange baseRange, bool linear,
13711370
ArrayRef<ParsedAutoDiffParameter> params,
13721371
Optional<DeclNameWithLoc> jvp,
13731372
Optional<DeclNameWithLoc> vjp,
13741373
TrailingWhereClause *clause)
13751374
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1376-
linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
1375+
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
13771376
VJP(std::move(vjp)), WhereClause(clause) {
13781377
std::copy(params.begin(), params.end(),
13791378
getTrailingObjects<ParsedAutoDiffParameter>());
13801379
}
13811380

1382-
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
1381+
DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
13831382
SourceLoc atLoc, SourceRange baseRange,
13841383
bool linear,
1385-
IndexSubset *indices,
1384+
IndexSubset *parameterIndices,
13861385
Optional<DeclNameWithLoc> jvp,
13871386
Optional<DeclNameWithLoc> vjp,
13881387
GenericSignature derivativeGenSig)
13891388
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1390-
linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)),
1391-
ParameterIndices(indices) {
1392-
setDerivativeGenericSignature(context, derivativeGenSig);
1389+
Linear(linear), JVP(std::move(jvp)), VJP(std::move(vjp)) {
1390+
setParameterIndices(parameterIndices);
1391+
setDerivativeGenericSignature(derivativeGenSig);
13931392
}
13941393

13951394
DifferentiableAttr *
@@ -1402,22 +1401,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
14021401
TrailingWhereClause *clause) {
14031402
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
14041403
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
1405-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1406-
linear, parameters, std::move(jvp),
1404+
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
1405+
parameters, std::move(jvp),
14071406
std::move(vjp), clause);
14081407
}
14091408

14101409
DifferentiableAttr *
1411-
DifferentiableAttr::create(ASTContext &context, bool implicit,
1412-
SourceLoc atLoc, SourceRange baseRange,
1413-
bool linear, IndexSubset *indices,
1410+
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
1411+
SourceLoc atLoc, SourceRange baseRange, bool linear,
1412+
IndexSubset *parameterIndices,
14141413
Optional<DeclNameWithLoc> jvp,
14151414
Optional<DeclNameWithLoc> vjp,
14161415
GenericSignature derivativeGenSig) {
1417-
void *mem = context.Allocate(sizeof(DifferentiableAttr),
1418-
alignof(DifferentiableAttr));
1419-
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
1420-
linear, indices, std::move(jvp),
1416+
auto &ctx = original->getASTContext();
1417+
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
1418+
alignof(DifferentiableAttr));
1419+
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
1420+
linear, parameterIndices, std::move(jvp),
14211421
std::move(vjp), derivativeGenSig);
14221422
}
14231423

0 commit comments

Comments
 (0)