Skip to content

Commit 71ba224

Browse files
authored
[AutoDiff] Allow Referencing Parameters as indices in wrt: of @differentiating & @differentiable (swiftlang#25594)
- can specify arguments you want to differentiate using the order it appears in the original function
1 parent 8d56a06 commit 71ba224

File tree

11 files changed

+215
-17
lines changed

11 files changed

+215
-17
lines changed

include/swift/AST/Attr.def

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,10 @@ DECL_ATTR(differentiating, Differentiating,
428428
OnFunc | LongAttribute | AllowMultipleAttributes |
429429
NotSerialized, 90)
430430
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
431-
OnAccessor | OnFunc | OnConstructor | OnSubscript,
432-
/* Not serialized */ 91)
431+
OnAccessor | OnFunc | OnConstructor | OnSubscript,
432+
/* Not serialized */ 91)
433433
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
434-
OnVar, 92)
434+
OnVar, 92)
435435

436436
#undef TYPE_ATTR
437437
#undef DECL_ATTR_ALIAS

include/swift/AST/AutoDiff.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,36 @@ enum class DifferentiabilityKind: uint8_t {
3232

3333
class ParsedAutoDiffParameter {
3434
public:
35-
enum class Kind { Named, Self };
35+
enum class Kind { Named, Ordered, Self };
3636

3737
private:
3838
SourceLoc Loc;
3939
Kind Kind;
4040
union Value {
41-
struct { Identifier Name; }; // Index
41+
struct { Identifier Name; }; // Named
42+
struct { unsigned Index; }; // Ordered
4243
struct {}; // Self
4344
Value(Identifier name) : Name(name) {}
45+
Value(unsigned index) : Index(index) {}
4446
Value() {}
4547
} V;
4648

4749
public:
4850
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, Value value)
4951
: Loc(loc), Kind(kind), V(value) {}
52+
53+
ParsedAutoDiffParameter(SourceLoc loc, enum Kind kind, unsigned index)
54+
: Loc(loc), Kind(kind), V(index) {}
5055

5156
static ParsedAutoDiffParameter getNamedParameter(SourceLoc loc,
5257
Identifier name) {
5358
return { loc, Kind::Named, name };
5459
}
60+
61+
static ParsedAutoDiffParameter getOrderedParameter(SourceLoc loc,
62+
unsigned index) {
63+
return { loc, Kind::Ordered, index };
64+
}
5565

5666
static ParsedAutoDiffParameter getSelfParameter(SourceLoc loc) {
5767
return { loc, Kind::Self, {} };
@@ -61,6 +71,10 @@ class ParsedAutoDiffParameter {
6171
assert(Kind == Kind::Named);
6272
return V.Name;
6373
}
74+
75+
unsigned getIndex() const {
76+
return V.Index;
77+
}
6478

6579
enum Kind getKind() const {
6680
return Kind;

include/swift/AST/DiagnosticsParse.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1516,7 +1516,8 @@ ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
15161516
ERROR(expected_colon_after_label,PointsToFirstBadToken,
15171517
"expected a colon ':' after '%0'", (StringRef))
15181518
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
1519-
"expected a parameter, which can be a function parameter name or 'self'",
1519+
"expected a parameter, which can be a function parameter name, "
1520+
"parameter index, or 'self'",
15201521
())
15211522

15221523
// [differentiable ...] (sil-decl attr)

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2815,7 +2815,9 @@ ERROR(diff_params_clause_self_instance_method_only,none,
28152815
ERROR(diff_params_clause_self_must_be_first,none,
28162816
"'self' parameter must come first in the parameter list", ())
28172817
ERROR(diff_params_clause_params_not_original_order,none,
2818-
"parameter names must be specified in original order", ())
2818+
"parameters must be specified in original order", ())
2819+
ERROR(diff_params_clause_param_index_out_of_range,none,
2820+
"parameter index is larger than total number of parameters", ())
28192821
ERROR(diff_params_clause_no_inferred_parameters,PointsToFirstBadToken,
28202822
"no differentiation parameters could be inferred; must differentiate "
28212823
"with respect to at least one parameter conforming to 'Differentiable'",

include/swift/Parse/Parser.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,6 @@ class Parser {
961961
bool parseDifferentiationParametersClause(
962962
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
963963

964-
/// SWIFT_ENABLE_TENSORFLOW
965964
/// Parse the @differentiating attribute.
966965
ParserResult<DifferentiatingAttr>
967966
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);

lib/AST/Attr.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,13 @@ static std::string getDifferentiationParametersClauseString(
374374
case ParsedAutoDiffParameter::Kind::Self:
375375
printer << "self";
376376
break;
377+
case ParsedAutoDiffParameter::Kind::Ordered:
378+
auto *paramList = function->getParameters();
379+
assert(param.getIndex() <= paramList->size() &&
380+
"wrt parameter is out of range");
381+
auto *funcParam = paramList->get(param.getIndex());
382+
printer << funcParam->getNameStr();
383+
break;
377384
}
378385
}, [&] { printer << ", "; });
379386
if (parsedParams.size() > 1)

lib/Parse/ParseDecl.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,17 @@ bool Parser::parseDifferentiationParametersClause(
892892
paramLoc, paramName));
893893
break;
894894
}
895+
case tok::integer_literal: {
896+
unsigned paramNum;
897+
if (parseUnsignedInteger(
898+
paramNum, paramLoc,
899+
diag::diff_params_clause_expected_parameter))
900+
return true;
901+
902+
params.push_back(ParsedAutoDiffParameter::getOrderedParameter(
903+
paramLoc, paramNum));
904+
break;
905+
}
895906
case tok::kw_self: {
896907
paramLoc = consumeToken(tok::kw_self);
897908
params.push_back(ParsedAutoDiffParameter::getSelfParameter(paramLoc));
@@ -1960,15 +1971,15 @@ bool Parser::parseNewDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc,
19601971
break;
19611972
}
19621973

1963-
/// SWIFT_ENABLE_TENSORFLOW
1974+
// SWIFT_ENABLE_TENSORFLOW
19641975
case DAK_Differentiable: {
19651976
auto Attr = parseDifferentiableAttribute(AtLoc, Loc);
19661977
if (Attr.isNonNull())
19671978
Attributes.add(Attr.get());
19681979
break;
19691980
}
19701981

1971-
/// SWIFT_ENABLE_TENSORFLOW
1982+
// SWIFT_ENABLE_TENSORFLOW
19721983
case DAK_Differentiating: {
19731984
auto Attr = parseDifferentiatingAttribute(AtLoc, Loc);
19741985
if (Attr.isNonNull())

lib/Sema/TypeCheckAttr.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2856,6 +2856,7 @@ static AutoDiffParameterIndices *computeDifferentiationParameters(
28562856
auto *functionType = function->getInterfaceType()->eraseDynamicSelfType()
28572857
->castTo<AnyFunctionType>();
28582858
auto &params = *function->getParameters();
2859+
auto numParams = function->getParameters()->size();
28592860
auto isInstanceMethod = function->isInstanceMember();
28602861

28612862
// Diagnose if function has no parameters.
@@ -2934,6 +2935,22 @@ static AutoDiffParameterIndices *computeDifferentiationParameters(
29342935
builder.setParameter(builder.size() - 1);
29352936
break;
29362937
}
2938+
case ParsedAutoDiffParameter::Kind::Ordered: {
2939+
auto index = parsedWrtParams[i].getIndex();
2940+
if (index >= numParams) {
2941+
TC.diagnose(paramLoc, diag::diff_params_clause_param_index_out_of_range);
2942+
return nullptr;
2943+
}
2944+
// Parameter names must be specified in the original order.
2945+
if ((int)index <= lastIndex) {
2946+
TC.diagnose(paramLoc,
2947+
diag::diff_params_clause_params_not_original_order);
2948+
return nullptr;
2949+
}
2950+
builder.setParameter(index);
2951+
lastIndex = index;
2952+
break;
2953+
}
29372954
}
29382955
}
29392956
return builder.build(TC.Context);

test/AutoDiff/differentiable_attr_parse.swift

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ func bar(_ x: Float, _: Float) -> Float {
3737
return 1 + x
3838
}
3939

40+
@differentiable(wrt: (x)) // okay
41+
func bar(_ x: Float, _: Float) -> Float {
42+
return 1 + x
43+
}
44+
4045
@differentiable(wrt: self) // okay
4146
func bar(_ x: Float, _: Float) -> Float {
4247
return 1 + x
@@ -67,6 +72,31 @@ func slope2(_ x: Float) -> Float {
6772
return 2 * x
6873
}
6974

75+
@differentiable(wrt: y) // ok
76+
func two(x: Float, y: Float) -> Float {
77+
return x + y
78+
}
79+
80+
@differentiable(wrt: (x, y)) // ok
81+
func two(x: Float, y: Float) -> Float {
82+
return x + y
83+
}
84+
85+
@differentiable(wrt: (0, y)) // ok
86+
func two(x: Float, y: Float) -> Float {
87+
return x + y
88+
}
89+
90+
@differentiable(wrt: (x, 1)) // ok
91+
func two(x: Float, y: Float) -> Float {
92+
return x + y
93+
}
94+
95+
@differentiable(wrt: (0, 1)) // ok
96+
func two(x: Float, y: Float) -> Float {
97+
return x + y
98+
}
99+
70100
/// Bad
71101

72102
@differentiable(3) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
@@ -89,16 +119,26 @@ func bar(_ x: Float, _: Float) -> Float {
89119
return 1 + x
90120
}
91121

92-
@differentiable(wrt: (1), vjp: foo(_:_:)) // expected-error {{expected a parameter, which can be a function parameter name or 'self'}}
93-
func bar(_ x: Float, _: Float) -> Float {
94-
return 1 + x
95-
}
96-
97122
@differentiable(wrt: x, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
98123
func bar(_ x: Float, _ y: Float) -> Float {
99124
return 1 + x
100125
}
101126

127+
@differentiable(wrt: 0, 1) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
128+
func two(x: Float, y: Float) -> Float {
129+
return x + y
130+
}
131+
132+
@differentiable(wrt: 0, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
133+
func two(x: Float, y: Float) -> Float {
134+
return x + y
135+
}
136+
137+
@differentiable(wrt: 0,) // expected-error {{unexpected ',' separator}}
138+
func two(x: Float, y: Float) -> Float {
139+
return x + y
140+
}
141+
102142
@differentiable(vjp: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}}
103143
func bar(_ x: Float, _: Float) -> Float {
104144
return 1 + x

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func jvpUnknownParam(x: Float) -> Float {
192192
return x
193193
}
194194

195-
// expected-error @+1 {{parameter names must be specified in original order}}
195+
// expected-error @+1 {{parameters must be specified in original order}}
196196
@differentiable(wrt: (y, x))
197197
func jvpParamOrderNotIncreasing(x: Float, y: Float) -> Float {
198198
return x * y
@@ -769,3 +769,78 @@ func slope2(_ x: Float) -> Float {
769769
func slope3(_ x: Float) -> Float {
770770
return 3 * x
771771
}
772+
773+
// Index based 'wrt:'
774+
775+
struct NumberWrtStruct: Differentiable {
776+
var a, b: Float
777+
778+
@differentiable(wrt: 0) // ok
779+
@differentiable(wrt: 1) // ok
780+
func foo1(_ x: Float, _ y: Float) -> Float {
781+
return a*x + b*y
782+
}
783+
784+
@differentiable(wrt: -1) // expected-error {{expected a parameter, which can be a function parameter name, parameter index, or 'self'}}
785+
@differentiable(wrt: (1, x)) // expected-error {{parameters must be specified in original order}}
786+
func foo2(_ x: Float, _ y: Float) -> Float {
787+
return a*x + b*y
788+
}
789+
790+
@differentiable(wrt: (x, 1)) // ok
791+
@differentiable(wrt: (0)) // ok
792+
static func staticFoo1(_ x: Float, _ y: Float) -> Float {
793+
return x + y
794+
}
795+
796+
@differentiable(wrt: (1, 1)) // expected-error {{parameters must be specified in original order}}
797+
@differentiable(wrt: (2)) // expected-error {{parameter index is larger than total number of parameters}}
798+
static func staticFoo2(_ x: Float, _ y: Float) -> Float {
799+
return x + y
800+
}
801+
}
802+
803+
@differentiable(wrt: y) // ok
804+
func two1(x: Float, y: Float) -> Float {
805+
return x + y
806+
}
807+
808+
@differentiable(wrt: (x, y)) // ok
809+
func two2(x: Float, y: Float) -> Float {
810+
return x + y
811+
}
812+
813+
@differentiable(wrt: (0, y)) // ok
814+
func two3(x: Float, y: Float) -> Float {
815+
return x + y
816+
}
817+
818+
@differentiable(wrt: (x, 1)) // ok
819+
func two4(x: Float, y: Float) -> Float {
820+
return x + y
821+
}
822+
823+
@differentiable(wrt: (0, 1)) // ok
824+
func two5(x: Float, y: Float) -> Float {
825+
return x + y
826+
}
827+
828+
@differentiable(wrt: 2) // expected-error {{parameter index is larger than total number of parameters}}
829+
func two6(x: Float, y: Float) -> Float {
830+
return x + y
831+
}
832+
833+
@differentiable(wrt: (1, 0)) // expected-error {{parameters must be specified in original order}}
834+
func two7(x: Float, y: Float) -> Float {
835+
return x + y
836+
}
837+
838+
@differentiable(wrt: (1, x)) // expected-error {{parameters must be specified in original order}}
839+
func two8(x: Float, y: Float) -> Float {
840+
return x + y
841+
}
842+
843+
@differentiable(wrt: (y, 0)) // expected-error {{parameters must be specified in original order}}
844+
func two9(x: Float, y: Float) -> Float {
845+
return x + y
846+
}

test/AutoDiff/differentiating_attr_type_checking.swift

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func vjpAddWrtXY(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Flo
8484
func vjpUnknownParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) {
8585
return (x + y, { $0 })
8686
}
87-
// expected-error @+1 {{parameter names must be specified in original order}}
87+
// expected-error @+1 {{parameters must be specified in original order}}
8888
@differentiating(add, wrt: (y, x))
8989
func vjpParamOrderNotIncreasing(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
9090
return (x + y, { ($0, $0) })
@@ -315,3 +315,35 @@ func f(_ x: PropertyDiff) -> Float {
315315

316316
let a = gradient(at: PropertyDiff(), in: f)
317317
print(a)
318+
319+
// Index based 'wrt:'
320+
321+
func add2(x: Float, y: Float) -> Float {
322+
return x + y
323+
}
324+
325+
@differentiating(add2, wrt: (0, y)) // ok
326+
func two3(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
327+
return (x + y, { ($0, $0) })
328+
}
329+
330+
@differentiating(add2, wrt: (1)) // ok
331+
func two4(x: Float, y: Float) -> (value: Float, pullback: (Float) -> Float) {
332+
return (x + y, { $0 })
333+
}
334+
335+
336+
@differentiating(add2, wrt: 2) // expected-error {{parameter index is larger than total number of parameters}}
337+
func two5(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
338+
return (x + y, { ($0, $0) })
339+
}
340+
341+
@differentiating(add2, wrt: (1, x)) // expected-error {{parameters must be specified in original order}}
342+
func two6(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
343+
return (x + y, { ($0, $0) })
344+
}
345+
346+
@differentiating(add2, wrt: (1, 0)) // expected-error {{parameters must be specified in original order}}
347+
func two7(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
348+
return (x + y, { ($0, $0) })
349+
}

0 commit comments

Comments
 (0)