@@ -4,12 +4,15 @@ import StdlibUnittest
4
4
5
5
var ProtocolRequirementAutodiffTests = TestSuite ( " ProtocolRequirementAutodiff " )
6
6
7
+ // MARK: - Func requirements.
8
+
7
9
protocol DiffReq : Differentiable {
8
10
@differentiable ( wrt: ( self , x) )
9
11
func f( _ x: Float ) -> Float
10
12
}
11
13
12
14
extension DiffReq where TangentVector : AdditiveArithmetic {
15
+ @inline ( never) // Prevent specialization, to test all witness code.
13
16
func gradF( at x: Float ) -> ( Self . TangentVector , Float ) {
14
17
return ( valueWithPullback ( at: x) { s, x in s. f ( x) } ) . 1 ( 1 )
15
18
}
@@ -53,7 +56,76 @@ extension Quadratic : VectorProtocol {
53
56
}
54
57
}
55
58
56
- // Test witness method SIL type computation.
59
+ ProtocolRequirementAutodiffTests . test ( " func " ) {
60
+ expectEqual ( ( Quadratic ( 0 , 0 , 1 ) , 12 ) , Quadratic ( 11 , 12 , 13 ) . gradF ( at: 0 ) )
61
+ expectEqual ( ( Quadratic ( 1 , 1 , 1 ) , 2 * 11 + 12 ) ,
62
+ Quadratic ( 11 , 12 , 13 ) . gradF ( at: 1 ) )
63
+ expectEqual ( ( Quadratic ( 4 , 2 , 1 ) , 2 * 11 * 2 + 12 ) ,
64
+ Quadratic ( 11 , 12 , 13 ) . gradF ( at: 2 ) )
65
+ }
66
+
67
+ // MARK: Constructor, accessor, and subscript requirements.
68
+
69
+ protocol FunctionsOfX : Differentiable {
70
+ @differentiable
71
+ init ( x: Float )
72
+
73
+ @differentiable
74
+ var x : Float { get }
75
+
76
+ @differentiable
77
+ var y : Float { get }
78
+
79
+ @differentiable
80
+ var z : Float { get }
81
+
82
+ @differentiable
83
+ subscript( ) -> Float { get }
84
+ }
85
+
86
+ struct TestFunctionsOfX : FunctionsOfX {
87
+ @differentiable
88
+ init ( x: Float ) {
89
+ self . x = x
90
+ self . y = x * x
91
+ }
92
+
93
+ /// x = x
94
+ var x : Float
95
+
96
+ /// y = x * x
97
+ var y : Float
98
+
99
+ /// z = x * x + x
100
+ var z : Float {
101
+ return y + x
102
+ }
103
+
104
+ @differentiable
105
+ subscript( ) -> Float {
106
+ return z
107
+ }
108
+ }
109
+
110
+ @inline ( never) // Prevent specialization, to test all witness code.
111
+ func derivatives< F: FunctionsOfX > ( at x: Float , in: F . Type )
112
+ -> ( Float , Float , Float , Float )
113
+ {
114
+ let dxdx = gradient ( at: x) { x in F ( x: x) . x }
115
+ let dydx = gradient ( at: x) { x in F ( x: x) . y }
116
+ let dzdx = gradient ( at: x) { x in F ( x: x) . z }
117
+ let dsubscriptdx = gradient ( at: x) { x in F ( x: x) [ ] }
118
+ return ( dxdx, dydx, dzdx, dsubscriptdx)
119
+ }
120
+
121
+ ProtocolRequirementAutodiffTests . test ( " constructor, accessor, subscript " ) {
122
+ expectEqual (
123
+ derivatives ( at: 2.0 , in: TestFunctionsOfX . self) ,
124
+ ( 1.0 , 4.0 , 5.0 , 5.0 ) )
125
+ }
126
+
127
+ // MARK: - Test witness method SIL type computation.
128
+
57
129
protocol P : Differentiable {
58
130
@differentiable ( wrt: ( x, y) )
59
131
func foo( _ x: Float , _ y: Double ) -> Float
@@ -65,12 +137,4 @@ struct S : P {
65
137
}
66
138
}
67
139
68
- ProtocolRequirementAutodiffTests . test ( " Trivial " ) {
69
- expectEqual ( ( Quadratic ( 0 , 0 , 1 ) , 12 ) , Quadratic ( 11 , 12 , 13 ) . gradF ( at: 0 ) )
70
- expectEqual ( ( Quadratic ( 1 , 1 , 1 ) , 2 * 11 + 12 ) ,
71
- Quadratic ( 11 , 12 , 13 ) . gradF ( at: 1 ) )
72
- expectEqual ( ( Quadratic ( 4 , 2 , 1 ) , 2 * 11 * 2 + 12 ) ,
73
- Quadratic ( 11 , 12 , 13 ) . gradF ( at: 2 ) )
74
- }
75
-
76
140
runAllTests ( )
0 commit comments