Skip to content

Commit afe9307

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Change @differentiable syntax to use parameter names. (#22000)
* [AutoDiff] Change `@differentiable` syntax to use parameter names. Use parameter names instead of parameter indices in `@differentiable` attribute `wrt` list. The old syntax used parameter indices with a prefix dot, e.g. `wrt: (.0, .1)`. The new syntax is more readable and understandable when good parameter names are chosen. New syntax: ``` @differentiable(wrt: (self, input)) func applied(to input: Float) -> Output { ... } ``` The internal, lowered representation of parameter indices is unchanged. This patch affects only AST attribute parsing/print. Update `wrt` clause usages in stdlib and tests. Todos: - Support parameter name `self` in `wrt` clause, probably via escape backticks: ``self``. - Low priority: add fixits/specific diagnostics for old parameter index style. * Address comments from @rxwei. - Don't allocate auxiliary vector. - Update parameter index related diagnostics. * Support single `wrt` parameter without parentheses. Single `wrt` parameters don't require parenthesis. The following are now valid: ``` @differentiable(wrt: x) @differentiable(wrt: self) ``` Multiple parameters still require parentheses: `@differentiable(wrt: x, y)`. * Rewrite stdlib usages of `@differentiable`, update libSyntax support. Rewrite stdlib usages of `@differentiable` with one differentiation parameter: `wrt: (self)` -> `wrt: self`. Note: printing still always prints parentheses around the differentiation parameter list. * Fix `test/Syntax/round_trip_parse_gen.swift`.
1 parent da4fd4d commit afe9307

25 files changed

+203
-160
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,34 +25,34 @@ namespace swift {
2525

2626
class ParsedAutoDiffParameter {
2727
public:
28-
enum class Kind { Index, Self };
28+
enum class Kind { Named, Self };
2929

3030
private:
3131
SourceLoc Loc;
3232
Kind Kind;
3333
union Value {
34-
struct { unsigned Index; }; // Index
34+
struct { Identifier Name; }; // Index
3535
struct {}; // Self
36-
Value(unsigned index) : Index(index) {}
36+
Value(Identifier name) : Name(name) {}
3737
Value() {}
3838
} V;
3939

4040
public:
4141
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
4242
: Loc(loc), Kind(kind), V(value) {}
4343

44-
static ParsedAutoDiffParameter getIndexParameter(SourceLoc loc,
45-
unsigned index) {
46-
return { loc, Kind::Index, index };
44+
static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
45+
Identifier name) {
46+
return { loc, Kind::Named, name };
4747
}
4848

4949
static ParsedAutoDiffParameter getSelfParameter(SourceLoc loc) {
5050
return { loc, Kind::Self, {} };
5151
}
5252

53-
unsigned getIndex() const {
54-
assert(Kind == Kind::Index);
55-
return V.Index;
53+
Identifier getName() const {
54+
assert(Kind == Kind::Named);
55+
return V.Name;
5656
}
5757

5858
enum Kind getKind() const {
@@ -64,8 +64,8 @@ class ParsedAutoDiffParameter {
6464
}
6565

6666
bool isEqual(const ParsedAutoDiffParameter &other) const {
67-
if (getKind() == other.getKind() && getKind() == Kind::Index)
68-
return getIndex() == other.getIndex();
67+
if (getKind() == other.getKind() && getKind() == Kind::Named)
68+
return getName() == other.getName();
6969
return getKind() == other.getKind() && getKind() == Kind::Self;
7070
}
7171
};

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,17 +1466,11 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
14661466
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
14671467
"expected a qualified %0 function", (StringRef))
14681468
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
1469-
"expected a list of parameters to differentiate with respect to, e.g. (.0, w, b)", ())
1469+
"expected a list of parameters to differentiate with respect to", ())
14701470
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
14711471
"use 'wrt:' to specify parameters to differentiate with respect to", ())
14721472
ERROR(attr_differentiable_expected_parameter,PointsToFirstBadToken,
1473-
"expected a parameter, which can be the index of a function parameter with a leading dot (e.g. '.0'), or 'self'", ())
1474-
ERROR(attr_differentiable_parameter_index_must_be_positive_integer,PointsToFirstBadToken,
1475-
"parameter index must be a positive integer", ())
1476-
ERROR(attr_differentiable_nondifferentiable_function,none,
1477-
"function '%0' is not differentiable", (StringRef))
1478-
ERROR(attr_differentiable_duplicate_config_option,PointsToFirstBadToken,
1479-
"duplicated gradient configuration option '%0'", (StringRef))
1473+
"expected a parameter, which can be a function parameter name or 'self'", ())
14801474
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
14811475
"missing label '%0:' in '@differentiable' attribute", (StringRef))
14821476
ERROR(attr_differentiable_expected_colon_after_label,PointsToFirstBadToken,

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,10 +2665,10 @@ ERROR(differentiable_attr_overload_not_found,none,
26652665
"%0 does not have expected type %1", (DeclName, Type))
26662666
ERROR(differentiable_attr_wrt_self_must_be_first,none,
26672667
"'self' parameter must come first in the parameter list", ())
2668-
ERROR(differentiable_attr_wrt_indices_must_be_ascending,none,
2669-
"parameter indices must be ascending", ())
2670-
ERROR(differentiable_attr_wrt_index_out_of_bounds,none,
2671-
"parameter index out of bounds", ())
2668+
ERROR(differentiable_attr_wrt_names_not_original_order,none,
2669+
"parameter names must be specified in original order", ())
2670+
ERROR(differentiable_attr_wrt_name_unknown,none,
2671+
"unknown parameter name %0", (Identifier))
26722672
ERROR(differentiable_attr_wrt_self_instance_method_only,none,
26732673
"'self' parameter is only applicable to instance methods", ())
26742674
ERROR(differentiable_attr_cannot_diff_wrt_objects_or_existentials,none,

lib/AST/Attr.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include "swift/AST/GenericEnvironment.h"
2222
#include "swift/AST/Module.h"
2323
#include "swift/AST/Types.h"
24+
// SWIFT_ENABLE_TENSORFLOW
25+
#include "swift/AST/ParameterList.h"
2426
#include "swift/Basic/Defer.h"
2527
#include "llvm/ADT/SmallString.h"
2628
#include "llvm/Support/raw_ostream.h"
@@ -580,16 +582,16 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
580582
if (isProperty || (isMethod && index == indices->parameters.size() - 1))
581583
Printer << "self";
582584
else
583-
Printer << "." << index;
585+
Printer << original->getParameters()->get(index)->getName().str();
584586
}, [&] { Printer << ", "; });
585587
Printer << ")";
586588
} else if (!parsedParams.empty()) {
587589
printCommaIfNecessary();
588590
Printer << "wrt: (";
589591
interleave(parsedParams, [&](const ParsedAutoDiffParameter &param) {
590592
switch (param.getKind()) {
591-
case ParsedAutoDiffParameter::Kind::Index:
592-
Printer << '.' << param.getIndex();
593+
case ParsedAutoDiffParameter::Kind::Named:
594+
Printer << '.' << param.getName();
593595
break;
594596
case ParsedAutoDiffParameter::Kind::Self:
595597
Printer << "self";

lib/Parse/ParseDecl.cpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -911,37 +911,29 @@ bool Parser::parseDifferentiableAttributeArguments(
911911
return errorAndSkipToEnd();
912912
}
913913
if (Tok.is(tok::identifier) && Tok.getText() == "wrt") {
914-
SyntaxParsingContext DiffParamsContext(
915-
SyntaxContext, SyntaxKind::DifferentiableAttributeDiffParams);
914+
SyntaxParsingContext DiffParamsClauseContext(
915+
SyntaxContext, SyntaxKind::DifferentiableAttributeDiffParamsClause);
916916
consumeToken(tok::identifier);
917917
if (!consumeIf(tok::colon)) {
918918
diagnose(Tok, diag::attr_differentiable_expected_colon_after_label,
919919
"wrt");
920920
return errorAndSkipToEnd();
921921
}
922-
SourceLoc leftLoc;
923-
if (parseToken(tok::l_paren, leftLoc,
924-
diag::attr_differentiable_expected_parameter_list)) {
925-
return errorAndSkipToEnd();
926-
}
927922

928923
// Function that parses a parameter into `params`. Returns true if error
929924
// occurred.
930-
auto parseParam = [&]() -> bool {
925+
auto parseParam = [&](bool parseTrailingComma = true) -> bool {
931926
SyntaxParsingContext DiffParamContext(
932927
SyntaxContext, SyntaxKind::DifferentiableAttributeDiffParam);
933928
SourceLoc paramLoc;
934929
switch (Tok.getKind()) {
935-
case tok::period_prefix: {
936-
SyntaxParsingContext IndexParamContext(
937-
SyntaxContext, SyntaxKind::DifferentiationIndexParam);
938-
consumeToken(tok::period_prefix);
939-
unsigned index;
940-
if (parseUnsignedInteger(index, paramLoc,
941-
diag::attr_differentiable_expected_parameter))
930+
case tok::identifier: {
931+
Identifier paramName;
932+
if (parseIdentifier(paramName, paramLoc,
933+
diag::attr_differentiable_expected_parameter))
942934
return true;
943-
params.push_back(
944-
ParsedAutoDiffParameter::getIndexParameter(paramLoc, index));
935+
params.push_back(ParsedAutoDiffParameter::getNamedParameter(
936+
paramLoc, paramName));
945937
break;
946938
}
947939
case tok::kw_self: {
@@ -953,24 +945,34 @@ bool Parser::parseDifferentiableAttributeArguments(
953945
diagnose(Tok, diag::attr_differentiable_expected_parameter);
954946
return true;
955947
}
956-
if (Tok.isNot(tok::r_paren))
948+
if (parseTrailingComma && Tok.isNot(tok::r_paren))
957949
return parseToken(tok::comma, diag::attr_expected_comma, AttrName,
958950
/*isDeclModifier=*/false);
959951
return false;
960952
};
961953

962-
// Parse first parameter. At least one is required.
963-
if (parseParam())
964-
return errorAndSkipToEnd(2);
965-
// Parse remaining parameters until ')'.
966-
while (Tok.isNot(tok::r_paren))
954+
// Parse opening '(' of the parameter list.
955+
if (Tok.is(tok::l_paren)) {
956+
SyntaxParsingContext DiffParamsContext(
957+
SyntaxContext, SyntaxKind::DifferentiableAttributeDiffParams);
958+
consumeToken(tok::l_paren);
959+
// Parse first parameter. At least one is required.
967960
if (parseParam())
968961
return errorAndSkipToEnd(2);
969-
970-
SyntaxContext->collectNodesInPlace(
971-
SyntaxKind::DifferentiableAttributeDiffParamList);
972-
// Parse closing ')' of the parameter list.
973-
consumeToken(tok::r_paren);
962+
// Parse remaining parameters until ')'.
963+
while (Tok.isNot(tok::r_paren))
964+
if (parseParam())
965+
return errorAndSkipToEnd(2);
966+
SyntaxContext->collectNodesInPlace(
967+
SyntaxKind::DifferentiableAttributeDiffParamList);
968+
// Parse closing ')' of the parameter list.
969+
consumeToken(tok::r_paren);
970+
}
971+
// If no opening '(' for parameter list, parse a single parameter.
972+
else {
973+
if (parseParam(/*parseTrailingComma*/ false))
974+
return errorAndSkipToEnd();
975+
}
974976
// If no trailing comma or 'where' clause, terminate parsing arguments.
975977
if (Tok.isNot(tok::comma) && Tok.isNot(tok::kw_where))
976978
return false;

lib/Sema/TypeCheckAttr.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,17 +2380,22 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23802380
for (unsigned i : indices(parsedWrtParams)) {
23812381
auto paramLoc = parsedWrtParams[i].getLoc();
23822382
switch (parsedWrtParams[i].getKind()) {
2383-
case ParsedAutoDiffParameter::Kind::Index: {
2384-
unsigned index = parsedWrtParams[i].getIndex();
2385-
if ((int)index <= lastIndex) {
2386-
TC.diagnose(paramLoc,
2387-
diag::differentiable_attr_wrt_indices_must_be_ascending);
2383+
case ParsedAutoDiffParameter::Kind::Named: {
2384+
auto nameIter =
2385+
llvm::find_if(originalParams.getArray(), [&](ParamDecl *param) {
2386+
return param->getName() == parsedWrtParams[i].getName();
2387+
});
2388+
// Parameter name must exist.
2389+
if (nameIter == originalParams.end()) {
2390+
TC.diagnose(paramLoc, diag::differentiable_attr_wrt_name_unknown,
2391+
parsedWrtParams[i].getName());
23882392
return;
23892393
}
2390-
// Parameter index cannot exceed bounds.
2391-
if (index >= originalParams.size()) {
2394+
// Parameter names must be specified in the original order.
2395+
unsigned index = std::distance(originalParams.begin(), nameIter);
2396+
if ((int)index <= lastIndex) {
23922397
TC.diagnose(paramLoc,
2393-
diag::differentiable_attr_wrt_index_out_of_bounds);
2398+
diag::differentiable_attr_wrt_names_not_original_order);
23942399
return;
23952400
}
23962401
autoDiffParameterIndicesBuilder.setParameter(index);

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ extension Tensor where Scalar : Differentiable & FloatingPoint {
617617
/// TensorFlow builtin conv2d gradient helper for the input.
618618
@inlinable
619619
@differentiable(
620-
wrt: (.1, .2),
620+
wrt: (filter, backpropOutput),
621621
vjp: _vjpTFConv2DBackpropInput(_:_:_:_:_:)
622622
)
623623
func _TFConv2DBackpropInput(
@@ -638,7 +638,7 @@ extension Tensor where Scalar : Differentiable & FloatingPoint {
638638
/// TensorFlow builtin conv2d gradient helper for the filter.
639639
@inlinable
640640
@differentiable(
641-
wrt: (.0, .2),
641+
wrt: (input, backpropOutput),
642642
vjp: _vjpTFConv2DBackpropFilter(_:_:_:_:_:)
643643
)
644644
func _TFConv2DBackpropFilter(

0 commit comments

Comments
 (0)