@@ -62,7 +62,6 @@ extension Float: HasFloat {
62
62
init ( float: Float ) { self = float }
63
63
}
64
64
65
- #if REQUIRES_SR14042
66
65
ReabstractionE2ETests . test ( " diff param generic => concrete " ) {
67
66
func inner< T: HasFloat > ( x: T ) -> Float {
68
67
7 * x. float * x. float
@@ -71,7 +70,6 @@ ReabstractionE2ETests.test("diff param generic => concrete") {
71
70
expectEqual ( Float ( 7 * 3 * 3 ) , transformed ( 3 ) )
72
71
expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 , of: transformed) )
73
72
}
74
- #endif
75
73
76
74
ReabstractionE2ETests . test ( " nondiff param generic => concrete " ) {
77
75
func inner< T: HasFloat > ( x: Float , y: T ) -> Float {
@@ -82,7 +80,6 @@ ReabstractionE2ETests.test("nondiff param generic => concrete") {
82
80
expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 ) { transformed ( $0, 10 ) } )
83
81
}
84
82
85
- #if REQUIRES_SR14042
86
83
ReabstractionE2ETests . test ( " diff param and nondiff param generic => concrete " ) {
87
84
func inner< T: HasFloat > ( x: T , y: T ) -> Float {
88
85
7 * x. float * x. float + y. float
@@ -91,9 +88,7 @@ ReabstractionE2ETests.test("diff param and nondiff param generic => concrete") {
91
88
expectEqual ( Float ( 7 * 3 * 3 + 10 ) , transformed ( 3 , 10 ) )
92
89
expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 ) { transformed ( $0, 10 ) } )
93
90
}
94
- #endif
95
91
96
- #if REQUIRES_SR14042
97
92
ReabstractionE2ETests . test ( " result generic => concrete " ) {
98
93
func inner< T: HasFloat > ( x: Float ) -> T {
99
94
T ( float: 7 * x * x)
@@ -102,7 +97,6 @@ ReabstractionE2ETests.test("result generic => concrete") {
102
97
expectEqual ( Float ( 7 * 3 * 3 ) , transformed ( 3 ) )
103
98
expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 , of: transformed) )
104
99
}
105
- #endif
106
100
107
101
ReabstractionE2ETests . test ( " diff param concrete => generic => concrete " ) {
108
102
typealias FnTy < T: Differentiable > = @differentiable ( reverse) ( T ) -> Float
@@ -152,21 +146,19 @@ ReabstractionE2ETests.test("@differentiable(reverse) function => opaque generic
152
146
func id< T> ( _ t: T ) -> T { t }
153
147
let inner : @differentiable ( reverse) ( Float ) -> Float = { 7 * $0 * $0 }
154
148
155
- // TODO(TF-1122): Actually using `id` causes a segfault at runtime.
156
- // let transformed = id(inner)
157
- // expectEqual(Float(7 * 3 * 3), transformed(3))
158
- // expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: id(inner)))
149
+ let transformed = id ( inner)
150
+ expectEqual ( Float ( 7 * 3 * 3 ) , transformed ( 3 ) )
151
+ expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 , of: id ( inner) ) )
159
152
}
160
153
161
154
ReabstractionE2ETests . test ( " @differentiable(reverse) function => opaque Any => concrete " ) {
162
155
func id( _ any: Any ) -> Any { any }
163
156
let inner : @differentiable ( reverse) ( Float ) -> Float = { 7 * $0 * $0 }
164
157
165
- // TODO(TF-1122): Actually using `id` causes a segfault at runtime.
166
- // let transformed = id(inner)
167
- // let casted = transformed as! @differentiable(reverse) (Float) -> Float
168
- // expectEqual(Float(7 * 3 * 3), casted(3))
169
- // expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: casted))
158
+ let transformed = id ( inner)
159
+ let casted = transformed as! @differentiable ( reverse) ( Float ) -> Float
160
+ expectEqual ( Float ( 7 * 3 * 3 ) , casted ( 3 ) )
161
+ expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 , of: casted) )
170
162
}
171
163
172
164
ReabstractionE2ETests . test ( " access @differentiable(reverse) function using KeyPath " ) {
@@ -176,10 +168,9 @@ ReabstractionE2ETests.test("access @differentiable(reverse) function using KeyPa
176
168
let container = Container ( f: { 7 * $0 * $0 } )
177
169
let kp = \Container . f
178
170
179
- // TODO(TF-1122): Actually using `kp` causes a segfault at runtime.
180
- // let extracted = container[keyPath: kp]
181
- // expectEqual(Float(7 * 3 * 3), extracted(3))
182
- // expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: extracted))
171
+ let extracted = container [ keyPath: kp]
172
+ expectEqual ( Float ( 7 * 3 * 3 ) , extracted ( 3 ) )
173
+ expectEqual ( Float ( 7 * 2 * 3 ) , gradient ( at: 3 , of: extracted) )
183
174
}
184
175
185
176
runAllTests ( )
0 commit comments