 // TODO:
 // - Add gradients for more ops ('sum', 'mean', etc).
 // - Fix gradients for broadcasting ops (need to perform reduction).
-// - When the trailing 'where' clause in @differentiable is properly
-//   type-checked, define constraints on BinaryFloatingPoint in original
-//   declarations and define adjoints on BinaryFloatingPoint.
 //
 // FIXME:
 // - Handle scalar broadcasting.
 // Elementwise binary
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : Numeric {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
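   // d/dx (x + y) = 1 and d/dy (x + y) = 1, so the seed propagates to both
   // operands unchanged.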
   @inlinable
   static func _adjointAdd(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor, _ y: Tensor
@@ -84,7 +81,7 @@ extension Tensor where Scalar : Numeric {
 }

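 // min/max gradient: the seed flows to whichever operand attained the
 // extremum; where x == y it is split evenly between the two (hence `denom`).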
 @inlinable
-func _adjointMinMax<T : Numeric & Comparable>(
+func _adjointMinMax<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   let denom = 1 + Tensor<T>(x .== y)
@@ -94,7 +91,7 @@ func _adjointMinMax<T : Numeric & Comparable>(
 }

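 // Power rule: d/dx x^y = y * x^(y-1); d/dy x^y = x^y * log(x).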
 @inlinable
-func _adjointPow<T : BinaryFloatingPoint>(
+func _adjointPow<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   return ((seed * y * pow(x, y - 1)).unbroadcast(like: x),
@@ -105,7 +102,7 @@ func _adjointPow<T : BinaryFloatingPoint>(
 // Elementwise unary
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : SignedNumeric {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointNegate(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor
@@ -115,90 +112,90 @@ extension Tensor where Scalar : SignedNumeric {
 }

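 // d/dx log(x) = 1/x.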
 @inlinable
-func _adjointLog<T : BinaryFloatingPoint>(
+func _adjointLog<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / x
 }
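 // d/dx sin(x) = cos(x).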
 @inlinable
-func _adjointSin<T : BinaryFloatingPoint>(
+func _adjointSin<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cos(x)
 }
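 // d/dx cos(x) = -sin(x).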
 @inlinable
-func _adjointCos<T : BinaryFloatingPoint>(
+func _adjointCos<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed * sin(x)
 }
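 // d/dx tan(x) = sec^2(x) = 1 + tan^2(x), so the forward result is reused.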
 @inlinable
-func _adjointTan<T : BinaryFloatingPoint>(
+func _adjointTan<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 + originalValue.squared())
 }
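 // d/dx sinh(x) = cosh(x).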
 @inlinable
-func _adjointSinh<T : BinaryFloatingPoint>(
+func _adjointSinh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cosh(x)
 }
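 // d/dx cosh(x) = sinh(x).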
 @inlinable
-func _adjointCosh<T : BinaryFloatingPoint>(
+func _adjointCosh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * sinh(x)
 }
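 // d/dx tanh(x) = 1 - tanh^2(x), reusing the forward result.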
 @inlinable
-func _adjointTanh<T : BinaryFloatingPoint>(
+func _adjointTanh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 - originalValue.squared())
 }
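 // d/dx exp(x) = exp(x), i.e. the forward result itself.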
 @inlinable
-func _adjointExp<T : BinaryFloatingPoint>(
+func _adjointExp<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return originalValue * seed
 }
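 // ceil is piecewise constant, so its gradient is zero almost everywhere.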
 @inlinable
-func _adjointCeil<T : BinaryFloatingPoint>(
+func _adjointCeil<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }
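 // floor is likewise piecewise constant; the gradient is zero almost everywhere.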
 @inlinable
-func _adjointFloor<T : BinaryFloatingPoint>(
+func _adjointFloor<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }
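 // d/dx sqrt(x) = 1 / (2 * sqrt(x)), reusing the forward result.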
 @inlinable
-func _adjointSqrt<T : BinaryFloatingPoint>(
+func _adjointSqrt<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / (2 * originalValue)
 }
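 // d/dx x^(-1/2) = -x^(-3/2) / 2 = -rsqrt(x)^3 / 2, reusing the forward result.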
 @inlinable
-func _adjointRsqrt<T : BinaryFloatingPoint>(
+func _adjointRsqrt<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed / 2 * pow(originalValue, 3)
 }
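 // d/dx x^2 = 2x.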
-func _adjointSquared<T : BinaryFloatingPoint>(
+func _adjointSquared<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return 2 * x * seed
@@ -209,7 +206,7 @@ func _adjointSquared<T : BinaryFloatingPoint>(
 //===----------------------------------------------------------------------===//

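 // For C = matmul(A, B): dA = matmul(seed, B^T) and dB = matmul(A^T, seed).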
 @inlinable
-func _adjointMatmul<Scalar : Numeric>(
+func _adjointMatmul<Scalar : Differentiable & FloatingPoint>(
   _ seed: Tensor<Scalar>, _ originalValue: Tensor<Scalar>,
   _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
 ) -> (Tensor<Scalar>, Tensor<Scalar>) {
@@ -220,16 +217,14 @@ func _adjointMatmul<Scalar : Numeric>(
 // TODO: We have to define a custom adjoint on • because AD can't yet
 // differentiate generic methods. After AD can differentiate generic methods,
 // remove the custom adjoint.
-extension Tensor where Scalar : Numeric {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointMatmulOperator(seed: Tensor, originalValue: Tensor,
                                      lhs: Tensor, rhs: Tensor)
     -> (Tensor, Tensor) {
     return _adjointMatmul(seed, originalValue, lhs, rhs)
   }
-}

-extension Tensor {
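   // The adjoint of a dimension permutation is the transpose of the seed by
   // the inverse permutation of `permutations`.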
   @inlinable
   func _adjointTransposed(
     _ seed: Tensor, _ originalValue: Tensor, _ permutations: Tensor<Int32>
@@ -243,7 +238,7 @@ extension Tensor {
 // Shape transformations
 //===----------------------------------------------------------------------===//

-extension Tensor {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
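   // Reshaping preserves element order, so the adjoint simply reshapes the
   // seed back to the input's original shape.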
   @inlinable
   func _adjointReshaped(
     seed: Tensor, originalValue: Tensor, toShape newShape: Tensor<Int32>
@@ -265,9 +260,8 @@ extension Tensor {
 // Normalization
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : BinaryFloatingPoint,
-                       Scalar : Differentiable,
-                       Scalar.CotangentVector == Scalar {
+extension Tensor where Scalar : BinaryFloatingPoint & Differentiable,
+                       Scalar == Scalar.CotangentVector {
   // TODO: Verify that these calculations are correct.
   @inlinable
   func _adjointBatchNormalized(
@@ -304,7 +298,7 @@ extension Tensor where Scalar : BinaryFloatingPoint,
 // Convolution and pooling
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : BinaryFloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   /// TensorFlow builtin conv2d gradient helper for the input.
   @inlinable
   @differentiable(
@@ -448,7 +442,7 @@ extension Tensor where Scalar : BinaryFloatingPoint {
 //===----------------------------------------------------------------------===//

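 // relu(x) = max(x, 0), so the incoming seed is gated by the mask (x > 0).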
 @inlinable
-func _adjointRelu<T : BinaryFloatingPoint>(
+func _adjointRelu<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(x .> 0) * seed