@@ -67,6 +67,31 @@ func slope2(_ x: Float) -> Float {
67
67
return 2 * x
68
68
}
69
69
70
+ @differentiable ( wrt: y) // ok
71
+ func two( x: Float , y: Float ) -> Float {
72
+ return x + y
73
+ }
74
+
75
+ @differentiable ( wrt: ( x, y) ) // ok
76
+ func two( x: Float , y: Float ) -> Float {
77
+ return x + y
78
+ }
79
+
80
+ @differentiable ( wrt: ( 0 , y) ) // ok
81
+ func two( x: Float , y: Float ) -> Float {
82
+ return x + y
83
+ }
84
+
85
+ @differentiable ( wrt: ( x, 1 ) ) // ok
86
+ func two( x: Float , y: Float ) -> Float {
87
+ return x + y
88
+ }
89
+
90
+ @differentiable ( wrt: ( 0 , 1 ) ) // ok
91
+ func two( x: Float , y: Float ) -> Float {
92
+ return x + y
93
+ }
94
+
70
95
/// Bad
71
96
72
97
@differentiable ( 3 ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
@@ -99,6 +124,21 @@ func bar(_ x: Float, _ y: Float) -> Float {
99
124
return 1 + x
100
125
}
101
126
127
+ @differentiable ( wrt: 0 , 1 ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
128
+ func two( x: Float , y: Float ) -> Float {
129
+ return x + y
130
+ }
131
+
132
+ @differentiable ( wrt: 0 , y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
133
+ func two( x: Float , y: Float ) -> Float {
134
+ return x + y
135
+ }
136
+
137
+ @differentiable ( wrt: 0 , ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
138
+ func two( x: Float , y: Float ) -> Float {
139
+ return x + y
140
+ }
141
+
102
142
@differentiable ( vjp: foo ( _: _: ) // expected-error {{expected ')' in 'differentiable' attribute}}
103
143
func bar( _ x: Float , _: Float ) -> Float {
104
144
return 1 + x
0 commit comments