Commit 99912b3

[AutoDiff] [API] Revamp @differentiable usages in stdlib.
- Use `FloatingPoint` rather than `BinaryFloatingPoint` to constrain differentiability.
  - Follows from:
    - #21673
    - tensorflow/swift-bindings#11
- Use `@differentiable` where clauses to constrain differentiability of numeric operations.
  - The most common constraint is `where Scalar : FloatingPoint` because `Tensor` conditionally conforms to `Differentiable where Scalar : FloatingPoint`.

Todos:
- Make more `Tensor` operations differentiable.
  - This includes reduction and broadcasting ops.
  - This is enabled by `@differentiable` where clause type-checking.
- Use VJP functions instead of adjoint functions.
  - I would prefer that this be done in a separate patch, after this patch adds the correct `@differentiable` where clauses.
- Add tests for newly `@differentiable` `Tensor` operations.
1 parent e15142c commit 99912b3
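For readers unfamiliar with the constraint choice described in the message, here is a minimal sketch of the conditional-conformance pattern it refers to. `ToyTensor` and `ToyDifferentiable` are made-up stand-ins, not the actual stdlib declarations: a generic wrapper only gains a capability when its scalar type satisfies a constraint, so generic operations that rely on that capability repeat the same `Scalar : FloatingPoint` bound.

```swift
// Toy sketch only: hypothetical types illustrating why the common constraint
// in this patch is `where Scalar : FloatingPoint`.
protocol ToyDifferentiable {}

struct ToyTensor<Scalar> {
  var scalars: [Scalar]
}

// Conditional conformance: the wrapper is "differentiable" only when its
// scalar type is floating-point, mirroring
// `extension Tensor : Differentiable where Scalar : FloatingPoint`.
extension ToyTensor: ToyDifferentiable where Scalar: FloatingPoint {}

// Any generic operation whose differentiability depends on that conformance
// therefore constrains Scalar to FloatingPoint, not BinaryFloatingPoint.
func toyNegate<Scalar: FloatingPoint>(_ x: ToyTensor<Scalar>) -> ToyTensor<Scalar> {
  return ToyTensor(scalars: x.scalars.map { -$0 })
}
```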

File tree: 7 files changed (+108 -112 lines)


stdlib/public/TensorFlow/CompositeMath.swift

Lines changed: 4 additions & 4 deletions
@@ -19,22 +19,22 @@
 /// Computes `sigmoid` of the specified tensor element-wise.
 /// Specifically, computes `1 / (1 + exp(-x))`.
 @inlinable @inline(__always)
-public func sigmoid<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func sigmoid<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
   return 1 / (1 + exp(-x))
 }
 
 /// Computes `relu` of the specified tensor element-wise.
 /// Specifically, computes `max(0, x)`.
 @inlinable @inline(__always)
 @differentiable(adjoint: _adjointRelu(_:_:_:))
-public func relu<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func relu<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
   return max(0, x)
 }
 
 /// Computes the softmax of the specified tensor element-wise.
 /// Specifically, computes `exp(x) / exp(x).sum()`.
 @inlinable @inline(__always)
-public func softmax<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func softmax<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
   let expx = exp(x)
   let sum = expx.sum()
   return expx / sum
@@ -43,7 +43,7 @@ public func softmax<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
 /// Computes the softmax of the specified tensor along the specified axis.
 /// Specifically, computes `exp(x) / exp(x).sum(alongAxes: axis)`.
 @inlinable @inline(__always)
-public func softmax<T : BinaryFloatingPoint>(
+public func softmax<T : FloatingPoint>(
   _ x: Tensor<T>, alongAxis axis: Int32
 ) -> Tensor<T> {
   let expx = exp(x)
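As an aside, the element-wise math these functions compute is straightforward; here is a minimal scalar-array analogue of the `softmax` above (illustrative only, using Foundation's `exp` on `Double` rather than the `Tensor` op):

```swift
import Foundation

// Plain-Swift analogue of softmax(x) = exp(x) / exp(x).sum(), for reference.
func arraySoftmax(_ x: [Double]) -> [Double] {
  let expx = x.map { exp($0) }
  let sum = expx.reduce(0, +)
  return expx.map { $0 / sum }
}

// Example: arraySoftmax([1, 2, 3]) ≈ [0.090, 0.245, 0.665]; the results sum to 1.
```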

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 24 additions & 30 deletions
@@ -36,9 +36,6 @@
 // TODO:
 // - Add gradients for more ops ('sum', 'mean', etc).
 // - Fix gradients for broadcasting ops (need to perform reduction).
-// - When the trailing 'where' clause in @differentiable is properly
-//   type-checked, define constraints on BinaryFloatingPoint in original
-//   declarations and define adjoints on BinaryFloatingPoint.
 //
 // FIXME:
 // - Handle scalar broadcasting.
@@ -49,7 +46,7 @@
 // Elementwise binary
 //===----------------------------------------------------------------------===//
 
-extension Tensor where Scalar : Numeric {
+extension Tensor where Scalar : FloatingPoint {
   @inlinable
   static func _adjointAdd(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor, _ y: Tensor
@@ -84,7 +81,7 @@ extension Tensor where Scalar : Numeric {
 }
 
 @inlinable
-func _adjointMinMax<T : Numeric & Comparable>(
+func _adjointMinMax<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   let denom = 1 + Tensor<T>(x .== y)
@@ -94,7 +91,7 @@ func _adjointMinMax<T : Numeric & Comparable>(
 }
 
 @inlinable
-func _adjointPow<T : BinaryFloatingPoint>(
+func _adjointPow<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   return ((seed * y * pow(x, y-1)).unbroadcast(like: x),
@@ -105,7 +102,7 @@ func _adjointPow<T : BinaryFloatingPoint>(
 // Elementwise unary
 //===----------------------------------------------------------------------===//
 
-extension Tensor where Scalar : SignedNumeric {
+extension Tensor where Scalar : FloatingPoint {
   @inlinable
   static func _adjointNegate(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor
@@ -115,90 +112,90 @@ extension Tensor where Scalar : SignedNumeric {
 }
 
 @inlinable
-func _adjointLog<T : BinaryFloatingPoint>(
+func _adjointLog<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / x
 }
 
 @inlinable
-func _adjointSin<T : BinaryFloatingPoint>(
+func _adjointSin<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cos(x)
 }
 
 @inlinable
-func _adjointCos<T : BinaryFloatingPoint>(
+func _adjointCos<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed * sin(x)
 }
 
 @inlinable
-func _adjointTan<T : BinaryFloatingPoint>(
+func _adjointTan<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 + originalValue.squared())
 }
 
 @inlinable
-func _adjointSinh<T : BinaryFloatingPoint>(
+func _adjointSinh<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cosh(x)
 }
 
 @inlinable
-func _adjointCosh<T : BinaryFloatingPoint>(
+func _adjointCosh<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * sinh(x)
 }
 
 @inlinable
-func _adjointTanh<T : BinaryFloatingPoint>(
+func _adjointTanh<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 - originalValue.squared())
 }
 
 @inlinable
-func _adjointExp<T : BinaryFloatingPoint>(
+func _adjointExp<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return originalValue * seed
 }
 
 @inlinable
-func _adjointCeil<T : BinaryFloatingPoint>(
+func _adjointCeil<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }
 
 @inlinable
-func _adjointFloor<T : BinaryFloatingPoint>(
+func _adjointFloor<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }
 
 @inlinable
-func _adjointSqrt<T : BinaryFloatingPoint>(
+func _adjointSqrt<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / (2 * originalValue)
 }
 
 @inlinable
-func _adjointRsqrt<T : BinaryFloatingPoint>(
+func _adjointRsqrt<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed / 2 * pow(originalValue, 3)
 }
 
-func _adjointSquared<T : BinaryFloatingPoint>(
+func _adjointSquared<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return 2 * x * seed
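A quick way to sanity-check adjoints like the ones above is a finite-difference comparison. For example, the tanh adjoint uses `seed * (1 - originalValue.squared())` because `originalValue` is `tanh(x)` and d/dx tanh(x) = 1 - tanh(x)². A minimal scalar check (illustrative only, not part of the patch):

```swift
import Foundation

// Scalar sanity check of the tanh adjoint: analytic vs. central finite difference.
let x = 0.3
let seed = 1.0
let originalValue = tanh(x)                        // forward result, as AD would cache it
let analytic = seed * (1 - originalValue * originalValue)
let h = 1e-6
let numerical = (tanh(x + h) - tanh(x - h)) / (2 * h)
print(abs(analytic - numerical) < 1e-6)            // expected: true
```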
@@ -209,7 +206,7 @@ func _adjointSquared<T : BinaryFloatingPoint>(
 //===----------------------------------------------------------------------===//
 
 @inlinable
-func _adjointMatmul<Scalar : Numeric>(
+func _adjointMatmul<Scalar : FloatingPoint>(
   _ seed: Tensor<Scalar>, _ originalValue: Tensor<Scalar>,
   _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
 ) -> (Tensor<Scalar>, Tensor<Scalar>) {
@@ -220,16 +217,14 @@ func _adjointMatmul<Scalar : Numeric>(
 // TODO: We have to define a custom adjoint on • because AD can't yet
 // differentiate generic methods. After AD can differentiate generic methods,
 // remove the custom adjoint.
-extension Tensor where Scalar : Numeric {
+extension Tensor where Scalar : FloatingPoint {
   @inlinable
   static func _adjointMatmulOperator(seed: Tensor, originalValue: Tensor,
                                      lhs: Tensor, rhs: Tensor)
     -> (Tensor, Tensor) {
     return _adjointMatmul(seed, originalValue, lhs, rhs)
   }
-}
 
-extension Tensor {
   @inlinable
   func _adjointTransposed(
     _ seed: Tensor, _ originalValue: Tensor, _ permutations: Tensor<Int32>
@@ -243,7 +238,7 @@ extension Tensor {
 // Shape transformations
 //===----------------------------------------------------------------------===//
 
-extension Tensor {
+extension Tensor where Scalar : FloatingPoint {
   @inlinable
   func _adjointReshaped(
     seed: Tensor, originalValue: Tensor, toShape newShape: Tensor<Int32>
@@ -265,9 +260,8 @@ extension Tensor {
 // Normalization
 //===----------------------------------------------------------------------===//
 
-extension Tensor where Scalar : BinaryFloatingPoint,
-  Scalar : Differentiable,
-  Scalar.CotangentVector == Scalar {
+extension Tensor where Scalar : BinaryFloatingPoint & Differentiable,
+  Scalar == Scalar.CotangentVector {
   // TODO: Verify that these calculations are correct.
   @inlinable
   func _adjointBatchNormalized(
@@ -304,7 +298,7 @@ extension Tensor where Scalar : BinaryFloatingPoint,
 // Convolution and pooling
 //===----------------------------------------------------------------------===//
 
-extension Tensor where Scalar : BinaryFloatingPoint {
+extension Tensor where Scalar : FloatingPoint {
   /// TensorFlow builtin conv2d gradient helper for the input.
   @inlinable
   @differentiable(
@@ -448,7 +442,7 @@ extension Tensor where Scalar : BinaryFloatingPoint {
 //===----------------------------------------------------------------------===//
 
 @inlinable
-func _adjointRelu<T : BinaryFloatingPoint>(
+func _adjointRelu<T : FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(x .> 0) * seed
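The relu adjoint above masks the incoming seed by where the input was positive. A scalar analogue (a hypothetical helper for illustration only, not part of the patch):

```swift
// d/dx max(0, x) is 1 where x > 0 and 0 elsewhere, so the adjoint passes the
// seed through at positive inputs and zeroes it out otherwise.
func scalarAdjointRelu<T: FloatingPoint>(seed: T, x: T) -> T {
  return x > 0 ? seed : 0
}

// scalarAdjointRelu(seed: 2.5, x: 1.0)  == 2.5
// scalarAdjointRelu(seed: 2.5, x: -1.0) == 0.0
```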
