// Elementwise binary
//===----------------------------------------------------------------------===//

- extension Tensor where Scalar : FloatingPoint {
+ extension Tensor where Scalar : Differentiable & FloatingPoint {
  @inlinable
  static func _adjointAdd(
    _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor, _ y: Tensor
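Every `_adjoint*` function here follows the same convention: `seed` is the gradient flowing in from downstream, `originalValue` is the result of the forward operation, and the remaining arguments are the original inputs; the function returns the gradient with respect to each input. Since d(x + y)/dx = d(x + y)/dy = 1, the body elided by this hunk presumably just unbroadcasts the seed back to each operand's shape. A minimal sketch of that assumption, not the actual elided body:

    // Sketch only: addition passes the seed through to both inputs;
    // unbroadcast undoes any implicit broadcasting so each gradient
    // matches its operand's shape.
    return (seed.unbroadcast(like: x), seed.unbroadcast(like: y))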
@@ -81,7 +81,7 @@ extension Tensor where Scalar : FloatingPoint {
}

@inlinable
- func _adjointMinMax<T : FloatingPoint>(
+ func _adjointMinMax<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
) -> (Tensor<T>, Tensor<T>) {
  let denom = 1 + Tensor<T>(x .== y)
@@ -91,7 +91,7 @@ func _adjointMinMax<T : FloatingPoint>(
}
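The `denom` expression handles ties: `x .== y` yields 1 where the operands are equal and 0 elsewhere, so `denom` is 2 at ties and 1 otherwise, splitting the incoming gradient evenly wherever min/max is attained by both inputs. A small illustration of the mask, with made-up values:

    let x = Tensor<Float>([1, 2, 3])
    let y = Tensor<Float>([1, 5, 0])
    let denom = 1 + Tensor<Float>(x .== y)  // [2, 1, 1]: the tied element splits the seed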

@inlinable
- func _adjointPow<T : FloatingPoint>(
+ func _adjointPow<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
) -> (Tensor<T>, Tensor<T>) {
  return ((seed * y * pow(x, y - 1)).unbroadcast(like: x),
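The shown component is the power rule, d(x^y)/dx = y * x^(y-1). The second component, cut off by the hunk boundary, would come from d(x^y)/dy = x^y * ln(x), which can reuse the forward result. A sketch of that elided half, assuming the standard formula:

    // d(x^y)/dy = x^y * ln(x) = originalValue * log(x)
    (seed * originalValue * log(x)).unbroadcast(like: y)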
@@ -102,7 +102,7 @@ func _adjointPow<T : FloatingPoint>(
// Elementwise unary
//===----------------------------------------------------------------------===//

- extension Tensor where Scalar : FloatingPoint {
+ extension Tensor where Scalar : Differentiable & FloatingPoint {
  @inlinable
  static func _adjointNegate(
    _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor
@@ -112,90 +112,90 @@ extension Tensor where Scalar : FloatingPoint {
}

@inlinable
- func _adjointLog<T : FloatingPoint>(
+ func _adjointLog<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed / x
}

@inlinable
- func _adjointSin<T : FloatingPoint>(
+ func _adjointSin<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed * cos(x)
}

@inlinable
- func _adjointCos<T : FloatingPoint>(
+ func _adjointCos<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return -seed * sin(x)
}

@inlinable
- func _adjointTan<T : FloatingPoint>(
+ func _adjointTan<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed * (1 + originalValue.squared())
}
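Expressing tan'(x) as 1 + tan^2(x) (equivalently sec^2(x)) lets the adjoint reuse `originalValue` instead of evaluating another transcendental from `x`. A quick numeric check of the identity, with illustrative values:

    let x = Tensor<Float>([0.5])
    let t = tan(x)
    // 1 + t.squared()  ==  1 / (cos(x) * cos(x)), i.e. sec^2(x)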

@inlinable
- func _adjointSinh<T : FloatingPoint>(
+ func _adjointSinh<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed * cosh(x)
}

@inlinable
- func _adjointCosh<T : FloatingPoint>(
+ func _adjointCosh<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed * sinh(x)
}

@inlinable
- func _adjointTanh<T : FloatingPoint>(
+ func _adjointTanh<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed * (1 - originalValue.squared())
}
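The same reuse pattern as tan: tanh'(x) = 1 - tanh^2(x), so the gradient needs only the forward result. A sketch (hypothetical direct call, since `_adjointTanh` is an internal helper):

    let x = Tensor<Float>([0.3])
    let dx = _adjointTanh(Tensor<Float>([1]), tanh(x), x)  // approx. 1 - tanh(0.3)^2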

@inlinable
- func _adjointExp<T : FloatingPoint>(
+ func _adjointExp<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return originalValue * seed
}
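exp is its own derivative, so the adjoint multiplies the seed by the already-computed forward value rather than calling exp again:

    let x = Tensor<Float>([1.0])
    let y = exp(x)
    // d(e^x)/dx = e^x, which is exactly originalValue:
    // _adjointExp(seed, y, x) == y * seed, with no second call to exp.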

@inlinable
- func _adjointCeil<T : FloatingPoint>(
+ func _adjointCeil<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return Tensor(0).broadcast(like: x)
}

@inlinable
- func _adjointFloor<T : FloatingPoint>(
+ func _adjointFloor<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return Tensor(0).broadcast(like: x)
}

@inlinable
- func _adjointSqrt<T : FloatingPoint>(
+ func _adjointSqrt<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return seed / (2 * originalValue)
}

@inlinable
- func _adjointRsqrt<T : FloatingPoint>(
+ func _adjointRsqrt<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return -seed / 2 * pow(originalValue, 3)
}
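Both square-root adjoints reuse the forward value: sqrt'(x) = 1 / (2 * sqrt(x)) = 1 / (2 * originalValue), and rsqrt'(x) = -1/2 * x^(-3/2) = -1/2 * rsqrt(x)^3, which is what `-seed / 2 * pow(originalValue, 3)` computes (the expression groups as `((-seed) / 2) * pow(...)`). A spot check at x = 4:

    let x = Tensor<Float>([4.0])
    let r = 1 / sqrt(x)                    // rsqrt(4) = 0.5
    // -1 / 2 * pow(r, 3) == -0.0625 == d/dx of x^(-1/2) evaluated at 4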

- func _adjointSquared<T : FloatingPoint>(
+ func _adjointSquared<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return 2 * x * seed
@@ -206,7 +206,7 @@ func _adjointSquared<T : FloatingPoint>(
//===----------------------------------------------------------------------===//

@inlinable
- func _adjointMatmul<Scalar : FloatingPoint>(
+ func _adjointMatmul<Scalar : Differentiable & FloatingPoint>(
  _ seed: Tensor<Scalar>, _ originalValue: Tensor<Scalar>,
  _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
) -> (Tensor<Scalar>, Tensor<Scalar>) {
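The body is elided, but the standard reverse-mode rule for matrix multiplication gives dLeft = seed * right^T and dRight = left^T * seed. A sketch under that assumption, using `matmul` and `transposed()` (names assumed from the surrounding library API, not confirmed by this hunk):

    // dLeft  = seed x right^T   (has the shape of left)
    // dRight = left^T x seed    (has the shape of right)
    return (matmul(seed, right.transposed()),
            matmul(left.transposed(), seed))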
@@ -217,7 +217,7 @@ func _adjointMatmul<Scalar : FloatingPoint>(
// TODO: We have to define a custom adjoint on • because AD can't yet
// differentiate generic methods. After AD can differentiate generic methods,
// remove the custom adjoint.
- extension Tensor where Scalar : FloatingPoint {
+ extension Tensor where Scalar : Differentiable & FloatingPoint {
  @inlinable
  static func _adjointMatmulOperator(seed: Tensor, originalValue: Tensor,
                                     lhs: Tensor, rhs: Tensor)
@@ -238,7 +238,7 @@ extension Tensor where Scalar : FloatingPoint {
// Shape transformations
//===----------------------------------------------------------------------===//

- extension Tensor where Scalar : FloatingPoint {
+ extension Tensor where Scalar : Differentiable & FloatingPoint {
  @inlinable
  func _adjointReshaped(
    seed: Tensor, originalValue: Tensor, toShape newShape: Tensor<Int32>
@@ -298,7 +298,7 @@ extension Tensor where Scalar : BinaryFloatingPoint & Differentiable,
// Convolution and pooling
//===----------------------------------------------------------------------===//

- extension Tensor where Scalar : FloatingPoint {
+ extension Tensor where Scalar : Differentiable & FloatingPoint {
  /// TensorFlow builtin conv2d gradient helper for the input.
  @inlinable
  @differentiable(
@@ -442,7 +442,7 @@ extension Tensor where Scalar : FloatingPoint {
//===----------------------------------------------------------------------===//

@inlinable
- func _adjointRelu<T : FloatingPoint>(
+ func _adjointRelu<T : Differentiable & FloatingPoint>(
  _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
) -> Tensor<T> {
  return Tensor(x .> 0) * seed
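`x .> 0` builds the 0/1 mask of strictly positive inputs, so the seed passes through where ReLU was active and is zeroed elsewhere (including at x == 0). An illustration with made-up values:

    let x = Tensor<Float>([-1, 0, 2])
    let seed = Tensor<Float>([10, 10, 10])
    // Tensor<Float>(x .> 0) == [0, 0, 1], so the adjoint is [0, 0, 10]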
0 commit comments