@@ -57,24 +57,39 @@ public func squareRoot() -> Self {
57
57
return lhs
58
58
}
59
59
60
+ @differentiable ( linear) // okay
61
+ func identity( _ x: Float ) -> Float {
62
+ return x
63
+ }
64
+
65
+ @differentiable ( linear, wrt: x) // okay
66
+ func slope2( _ x: Float ) -> Float {
67
+ return 2 * x
68
+ }
69
+
70
+ @differentiable ( linear, wrt: x, vjp: const3) // okay
71
+ func slope3( _ x: Float ) -> Float {
72
+ return 3 * x
73
+ }
74
+
60
75
/// Bad
61
76
62
- @differentiable ( 3 ) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
77
+ @differentiable ( 3 ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
63
78
func bar( _ x: Float , _: Float ) -> Float {
64
79
return 1 + x
65
80
}
66
81
67
- @differentiable ( foo ( _: _: ) ) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
82
+ @differentiable ( foo ( _: _: ) ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
68
83
func bar( _ x: Float , _: Float ) -> Float {
69
84
return 1 + x
70
85
}
71
86
72
- @differentiable ( vjp: foo ( _: _: ) , 3 ) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
87
+ @differentiable ( vjp: foo ( _: _: ) , 3 ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
73
88
func bar( _ x: Float , _: Float ) -> Float {
74
89
return 1 + x
75
90
}
76
91
77
- @differentiable ( wrt: ( x) , foo ( _: _: ) ) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
92
+ @differentiable ( wrt: ( x) , foo ( _: _: ) ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
78
93
func bar( _ x: Float , _: Float ) -> Float {
79
94
return 1 + x
80
95
}
@@ -84,7 +99,7 @@ func bar(_ x: Float, _: Float) -> Float {
84
99
return 1 + x
85
100
}
86
101
87
- @differentiable ( wrt: x, y) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
102
+ @differentiable ( wrt: x, y) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
88
103
func bar( _ x: Float , _ y: Float ) -> Float {
89
104
return 1 + x
90
105
}
@@ -99,7 +114,7 @@ func bar<T : Numeric>(_ x: T, _: T) -> T {
99
114
return 1 + x
100
115
}
101
116
102
- @differentiable ( , ) // expected-error {{expected a function specifier label, e.g. 'wrt:', 'jvp:', or 'vjp:'}}
117
+ @differentiable ( , ) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
103
118
func bar( _ x: Float , _: Float ) -> Float {
104
119
return 1 + x
105
120
}
@@ -113,3 +128,18 @@ func bar(_ x: Float, _: Float) -> Float {
113
128
func bar< T : Numeric > ( _ x: T , _: T ) -> T {
114
129
return 1 + x
115
130
}
131
+
132
+ @differentiable ( wrt: x, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
133
+ func slope4( _ x: Float ) -> Float {
134
+ return 4 * x
135
+ }
136
+
137
+ @differentiable ( wrt: x, linear, vjp: const5) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
138
+ func slope5( _ x: Float ) -> Float {
139
+ return 5 * x
140
+ }
141
+
142
+ @differentiable ( wrt: x, vjp: const6, linear) // expected-error {{expected either 'wrt:' or a function specifier label, e.g. 'jvp:', or 'vjp:'}}
143
+ func slope5( _ x: Float ) -> Float {
144
+ return 6 * x
145
+ }
0 commit comments