Skip to content

Commit 1be86ad

Browse files
author
marcrasi
authored
[AutoDiff] forbid derivative registration using @differentiable (#30001)
Delete `@differentiable` attribute `jvp:` and `vjp:` arguments for derivative registration. `@derivative` attribute is now the canonical way to register derivatives. Resolves TF-1001.
1 parent d0fb9c6 commit 1be86ad

18 files changed

+138
-1063
lines changed

include/swift/AST/Attr.h

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,12 +1675,11 @@ struct DeclNameRefWithLoc {
16751675
DeclNameLoc Loc;
16761676
};
16771677

1678-
/// Attribute that marks a function as differentiable and optionally specifies
1679-
/// custom associated derivative functions: 'jvp' and 'vjp'.
1678+
/// Attribute that marks a function as differentiable.
16801679
///
16811680
/// Examples:
1682-
/// @differentiable(jvp: jvpFoo where T : FloatingPoint)
1683-
/// @differentiable(wrt: (self, x, y), jvp: jvpFoo)
1681+
/// @differentiable(where T : FloatingPoint)
1682+
/// @differentiable(wrt: (self, x, y))
16841683
class DifferentiableAttr final
16851684
: public DeclAttribute,
16861685
private llvm::TrailingObjects<DifferentiableAttr,
@@ -1696,16 +1695,6 @@ class DifferentiableAttr final
16961695
bool Linear;
16971696
/// The number of parsed differentiability parameters specified in 'wrt:'.
16981697
unsigned NumParsedParameters = 0;
1699-
/// The JVP function.
1700-
Optional<DeclNameRefWithLoc> JVP;
1701-
/// The VJP function.
1702-
Optional<DeclNameRefWithLoc> VJP;
1703-
/// The JVP function (optional), resolved by the type checker if JVP name is
1704-
/// specified.
1705-
FuncDecl *JVPFunction = nullptr;
1706-
/// The VJP function (optional), resolved by the type checker if VJP name is
1707-
/// specified.
1708-
FuncDecl *VJPFunction = nullptr;
17091698
/// The differentiability parameter indices, resolved by the type checker.
17101699
/// The bit stores whether the parameter indices have been computed.
17111700
///
@@ -1724,32 +1713,24 @@ class DifferentiableAttr final
17241713
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
17251714
SourceRange baseRange, bool linear,
17261715
ArrayRef<ParsedAutoDiffParameter> parameters,
1727-
Optional<DeclNameRefWithLoc> jvp,
1728-
Optional<DeclNameRefWithLoc> vjp,
17291716
TrailingWhereClause *clause);
17301717

17311718
explicit DifferentiableAttr(Decl *original, bool implicit, SourceLoc atLoc,
17321719
SourceRange baseRange, bool linear,
17331720
IndexSubset *parameterIndices,
1734-
Optional<DeclNameRefWithLoc> jvp,
1735-
Optional<DeclNameRefWithLoc> vjp,
17361721
GenericSignature derivativeGenericSignature);
17371722

17381723
public:
17391724
static DifferentiableAttr *create(ASTContext &context, bool implicit,
17401725
SourceLoc atLoc, SourceRange baseRange,
17411726
bool linear,
17421727
ArrayRef<ParsedAutoDiffParameter> params,
1743-
Optional<DeclNameRefWithLoc> jvp,
1744-
Optional<DeclNameRefWithLoc> vjp,
17451728
TrailingWhereClause *clause);
17461729

17471730
static DifferentiableAttr *create(AbstractFunctionDecl *original,
17481731
bool implicit, SourceLoc atLoc,
17491732
SourceRange baseRange, bool linear,
17501733
IndexSubset *parameterIndices,
1751-
Optional<DeclNameRefWithLoc> jvp,
1752-
Optional<DeclNameRefWithLoc> vjp,
17531734
GenericSignature derivativeGenSig);
17541735

17551736
Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
@@ -1758,16 +1739,6 @@ class DifferentiableAttr final
17581739
/// Should only be used by parsing and deserialization.
17591740
void setOriginalDeclaration(Decl *originalDeclaration);
17601741

1761-
/// Get the optional 'jvp:' function name and location.
1762-
/// Use this instead of `getJVPFunction` to check whether the attribute has a
1763-
/// registered JVP.
1764-
Optional<DeclNameRefWithLoc> getJVP() const { return JVP; }
1765-
1766-
/// Get the optional 'vjp:' function name and location.
1767-
/// Use this instead of `getVJPFunction` to check whether the attribute has a
1768-
/// registered VJP.
1769-
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }
1770-
17711742
private:
17721743
/// Returns true if the given `@differentiable` attribute has been
17731744
/// type-checked.
@@ -1800,21 +1771,14 @@ class DifferentiableAttr final
18001771
DerivativeGenericSignature = derivativeGenSig;
18011772
}
18021773

1803-
FuncDecl *getJVPFunction() const { return JVPFunction; }
1804-
void setJVPFunction(FuncDecl *decl);
1805-
FuncDecl *getVJPFunction() const { return VJPFunction; }
1806-
void setVJPFunction(FuncDecl *decl);
1807-
18081774
/// Get the derivative generic environment for the given `@differentiable`
18091775
/// attribute and original function.
18101776
GenericEnvironment *
18111777
getDerivativeGenericEnvironment(AbstractFunctionDecl *original) const;
18121778

18131779
// Print the attribute to the given stream.
18141780
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1815-
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
1816-
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false,
1817-
bool omitDerivativeFunctions = false) const;
1781+
void print(llvm::raw_ostream &OS, const Decl *D, bool omitWrtClause = false) const;
18181782

18191783
static bool classof(const DeclAttribute *DA) {
18201784
return DA->getKind() == DAK_Differentiable;

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,21 +1582,15 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
15821582
"expected a member name as second parameter in '_implements' attribute", ())
15831583

15841584
// differentiable
1585-
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
1586-
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
1587-
"expected a %0 function name", (StringRef))
15881585
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
15891586
"expected a list of parameters to differentiate with respect to", ())
1590-
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
15911587
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
15921588
"use 'wrt:' to specify parameters to differentiate with respect to", ())
1593-
ERROR(attr_differentiable_expected_label,none,
1594-
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
1595-
"or 'vjp:'", ())
1589+
ERROR(attr_differentiable_expected_label,none,"expected 'wrt:' or 'where'", ())
15961590
ERROR(attr_differentiable_unexpected_argument,none,
15971591
"unexpected argument '%0' in '@differentiable' attribute", (StringRef))
1598-
// TODO(TF-1001): Remove diagnostic when deprecated `jvp:`, `vjp:` are removed.
1599-
WARNING(attr_differentiable_jvp_vjp_deprecated_warning,none,
1592+
// TODO(TF-893): Remove this error after the 0.8 release.
1593+
ERROR(attr_differentiable_jvp_vjp_deprecated_error,none,
16001594
"'jvp:' and 'vjp:' arguments in '@differentiable' attribute are "
16011595
"deprecated; use '@derivative' attribute for derivative registration "
16021596
"instead", ())

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2914,6 +2914,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
29142914
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
29152915
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
29162916
"attribute for transpose registration instead", ())
2917+
ERROR(differentiable_attr_void_result,none,
2918+
"cannot differentiate void function %0", (DeclName))
29172919
ERROR(differentiable_attr_overload_not_found,none,
29182920
"%0 does not have expected type %1", (DeclNameRef, Type))
29192921
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
@@ -2938,12 +2940,6 @@ ERROR(differentiable_attr_result_not_differentiable,none,
29382940
ERROR(differentiable_attr_protocol_req_where_clause,none,
29392941
"'@differentiable' attribute on protocol requirement cannot specify "
29402942
"'where' clause", ())
2941-
ERROR(differentiable_attr_protocol_req_assoc_func,none,
2942-
"'@differentiable' attribute on protocol requirement cannot specify "
2943-
"'jvp:' or 'vjp:'", ())
2944-
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
2945-
"'@differentiable' attribute on stored property cannot specify "
2946-
"'jvp:' or 'vjp:'", ())
29472943
ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none,
29482944
"'@differentiable' attribute cannot be declared on class members "
29492945
"returning 'Self'", ())

include/swift/Parse/Parser.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,8 +1010,6 @@ class Parser {
10101010
/// Parse the arguments inside the @differentiable attribute.
10111011
bool parseDifferentiableAttributeArguments(
10121012
bool &linear, SmallVectorImpl<ParsedAutoDiffParameter> &params,
1013-
Optional<DeclNameRefWithLoc> &jvpSpec,
1014-
Optional<DeclNameRefWithLoc> &vjpSpec,
10151013
TrailingWhereClause *&whereClause);
10161014

10171015
/// Parse a differentiability parameters clause, i.e. the 'wrt:' clause in

lib/AST/Attr.cpp

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -526,12 +526,9 @@ static std::string getDifferentiationParametersClauseString(
526526
/// Print the arguments of the given `@differentiable` attribute.
527527
/// - If `omitWrtClause` is true, omit printing the `wrt:` differentiation
528528
/// parameters clause.
529-
/// - If `omitDerivativeFunctions` is true, omit printing the JVP/VJP derivative
530-
/// functions.
531529
static void printDifferentiableAttrArguments(
532530
const DifferentiableAttr *attr, ASTPrinter &printer, PrintOptions Options,
533-
const Decl *D, bool omitWrtClause = false,
534-
bool omitDerivativeFunctions = false) {
531+
const Decl *D, bool omitWrtClause = false) {
535532
assert(D);
536533
// Create a temporary string for the attribute argument text.
537534
std::string attrArgText;
@@ -574,19 +571,6 @@ static void printDifferentiableAttrArguments(
574571
stream << diffParamsString;
575572
}
576573
}
577-
// Print derivative function names, unless they are to be omitted.
578-
if (!omitDerivativeFunctions) {
579-
// Print jvp function name, if specified.
580-
if (auto jvp = attr->getJVP()) {
581-
printCommaIfNecessary();
582-
stream << "jvp: " << jvp->Name;
583-
}
584-
// Print vjp function name, if specified.
585-
if (auto vjp = attr->getVJP()) {
586-
printCommaIfNecessary();
587-
stream << "vjp: " << vjp->Name;
588-
}
589-
}
590574
// Print 'where' clause, if any.
591575
// First, filter out requirements satisfied by the original function's
592576
// generic signature. They should not be printed.
@@ -1616,12 +1600,9 @@ SPIAccessControlAttr::create(ASTContext &context,
16161600
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
16171601
SourceRange baseRange, bool linear,
16181602
ArrayRef<ParsedAutoDiffParameter> params,
1619-
Optional<DeclNameRefWithLoc> jvp,
1620-
Optional<DeclNameRefWithLoc> vjp,
16211603
TrailingWhereClause *clause)
16221604
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1623-
Linear(linear), NumParsedParameters(params.size()), JVP(std::move(jvp)),
1624-
VJP(std::move(vjp)), WhereClause(clause) {
1605+
Linear(linear), NumParsedParameters(params.size()), WhereClause(clause) {
16251606
std::copy(params.begin(), params.end(),
16261607
getTrailingObjects<ParsedAutoDiffParameter>());
16271608
}
@@ -1630,12 +1611,9 @@ DifferentiableAttr::DifferentiableAttr(Decl *original, bool implicit,
16301611
SourceLoc atLoc, SourceRange baseRange,
16311612
bool linear,
16321613
IndexSubset *parameterIndices,
1633-
Optional<DeclNameRefWithLoc> jvp,
1634-
Optional<DeclNameRefWithLoc> vjp,
16351614
GenericSignature derivativeGenSig)
16361615
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, implicit),
1637-
OriginalDeclaration(original), Linear(linear), JVP(std::move(jvp)),
1638-
VJP(std::move(vjp)) {
1616+
OriginalDeclaration(original), Linear(linear) {
16391617
setParameterIndices(parameterIndices);
16401618
setDerivativeGenericSignature(derivativeGenSig);
16411619
}
@@ -1645,29 +1623,23 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
16451623
SourceLoc atLoc, SourceRange baseRange,
16461624
bool linear,
16471625
ArrayRef<ParsedAutoDiffParameter> parameters,
1648-
Optional<DeclNameRefWithLoc> jvp,
1649-
Optional<DeclNameRefWithLoc> vjp,
16501626
TrailingWhereClause *clause) {
16511627
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
16521628
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
16531629
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
1654-
parameters, std::move(jvp),
1655-
std::move(vjp), clause);
1630+
parameters, clause);
16561631
}
16571632

16581633
DifferentiableAttr *
16591634
DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
16601635
SourceLoc atLoc, SourceRange baseRange, bool linear,
16611636
IndexSubset *parameterIndices,
1662-
Optional<DeclNameRefWithLoc> jvp,
1663-
Optional<DeclNameRefWithLoc> vjp,
16641637
GenericSignature derivativeGenSig) {
16651638
auto &ctx = original->getASTContext();
16661639
void *mem = ctx.Allocate(sizeof(DifferentiableAttr),
16671640
alignof(DifferentiableAttr));
16681641
return new (mem) DifferentiableAttr(original, implicit, atLoc, baseRange,
1669-
linear, parameterIndices, std::move(jvp),
1670-
std::move(vjp), derivativeGenSig);
1642+
linear, parameterIndices, derivativeGenSig);
16711643
}
16721644

16731645
void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
@@ -1701,18 +1673,6 @@ void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
17011673
std::move(paramIndices));
17021674
}
17031675

1704-
void DifferentiableAttr::setJVPFunction(FuncDecl *decl) {
1705-
JVPFunction = decl;
1706-
if (decl && !JVP)
1707-
JVP = {decl->createNameRef(), DeclNameLoc(decl->getNameLoc())};
1708-
}
1709-
1710-
void DifferentiableAttr::setVJPFunction(FuncDecl *decl) {
1711-
VJPFunction = decl;
1712-
if (decl && !VJP)
1713-
VJP = {decl->createNameRef(), DeclNameLoc(decl->getNameLoc())};
1714-
}
1715-
17161676
GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
17171677
AbstractFunctionDecl *original) const {
17181678
GenericEnvironment *derivativeGenEnv = original->getGenericEnvironment();
@@ -1722,12 +1682,10 @@ GenericEnvironment *DifferentiableAttr::getDerivativeGenericEnvironment(
17221682
}
17231683

17241684
void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
1725-
bool omitWrtClause,
1726-
bool omitDerivativeFunctions) const {
1685+
bool omitWrtClause) const {
17271686
StreamPrinter P(OS);
17281687
P << "@" << getAttrName();
1729-
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause,
1730-
omitDerivativeFunctions);
1688+
printDifferentiableAttrArguments(this, P, PrintOptions(), D, omitWrtClause);
17311689
}
17321690

17331691
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,

0 commit comments

Comments
 (0)