
Commit 134500f

[API] [AD] Revamp @differentiable usages in stdlib. (#21732)
* [AutoDiff] [API] Revamp `@differentiable` usages in stdlib.
  - Use `FloatingPoint` rather than `BinaryFloatingPoint` to constrain differentiability.
    - Follows from #21673 and tensorflow/swift-bindings#11.
  - Use `@differentiable` where clauses to constrain differentiability of numeric operations.
    - The most common constraint is `where Scalar : FloatingPoint` because `Tensor` conditionally conforms to `Differentiable where Scalar : FloatingPoint`.
  - `Tensor` now conditionally conforms to `Differentiable` where `Scalar : Differentiable & FloatingPoint`.
  - Allow `@differentiable` where clause conformance requirements to use protocol composition types.

  Todos:
  - Make more `Tensor` operations differentiable, including reduction and broadcasting ops. This is enabled by `@differentiable` where clause type-checking.
  - Use VJP functions instead of adjoint functions. I would prefer that this be done in a separate patch, after this patch adds the correct `@differentiable` where clauses.
  - Add tests for newly `@differentiable` `Tensor` operations.

* [AutoDiff] Make VJP applications use the correct substitution map.
  If a custom `@differentiable` attribute defines a VJP and where clause requirements, VJP applications should use a substitution map involving those requirements.
  Note: more related cases need to be handled, such as `@differentiable` attributes with where clause requirements but no VJP. These cases will be handled later.
1 parent 63b3d0b commit 134500f
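
The constraint change described above (constraining on `Differentiable & FloatingPoint` rather than `BinaryFloatingPoint`) follows the pattern sketched below. This is a minimal illustration assuming the swift-for-tensorflow toolchain, where `Differentiable` is a standard-library protocol; the `Tensor` stub and the `myScale` helper are hypothetical and not part of this commit.

// Hypothetical sketch of the new constraint pattern. The `Tensor` type here
// is a stand-in stub, not the real TensorFlow type.
struct Tensor<Scalar> {
  var scalars: [Scalar]
}

// Before this commit, helpers like this were typically constrained on
// `BinaryFloatingPoint`; afterwards they use the composition below, matching
// `Tensor`'s conditional `Differentiable` conformance.
func myScale<T>(_ x: Tensor<T>, by factor: T) -> Tensor<T>
  where T : Differentiable & FloatingPoint
{
  return Tensor(scalars: x.scalars.map { $0 * factor })
}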

File tree

9 files changed: +166 −131 lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 10 additions & 3 deletions
@@ -2539,9 +2539,16 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
       newArgs.push_back(getOpValue(origArg));
     assert(newArgs.size() == numVJPParams);
     // Apply the VJP.
-    auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp,
-                                             ai->getSubstitutionMap(), newArgs,
-                                             ai->isNonThrowing());
+    auto substMap = ai->getSubstitutionMap();
+    if (auto vjpGenSig = vjpFnTy->getGenericSignature()) {
+      auto vjpSubstMap =
+          vjpGenSig->createGenericEnvironment()->getForwardingSubstitutionMap();
+      substMap = vjpSubstMap.subst(
+          [&](SubstitutableType *ty) { return Type(ty).subst(substMap); },
+          LookUpConformanceInModule(context.getModule().getSwiftModule()));
+    }
+    auto *vjpCall = getBuilder().createApply(ai->getLoc(), vjp, substMap,
+                                             newArgs, ai->isNonThrowing());
     LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);

     // Get the VJP results (original results and pullback).
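
The substitution-map change above matters when a VJP declared in a `@differentiable` attribute is itself generic with where-clause requirements: the apply must be built with a substitution map derived from the VJP's generic signature rather than the original call's. For reference, a VJP in this model returns the original result together with a pullback closure, as in the hedged scalar sketch below (the names `myDouble` and `_vjpMyDouble` are illustrative only, not from this patch).

// A VJP pairs the original value with a pullback that maps an output
// cotangent (the seed) back to an input cotangent.
func myDouble(_ x: Float) -> Float {
  return 2 * x
}

func _vjpMyDouble(_ x: Float) -> (Float, (Float) -> Float) {
  return (myDouble(x), { seed in 2 * seed })
}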

lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 2 deletions
@@ -2345,14 +2345,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {

     // Conformance requirements are valid if:
     // - The first type is a generic type parameter type.
-    // - The second type is a protocol type.
+    // - The second type is a protocol type or protocol composition type.
     case RequirementKind::Conformance:
       if (diagnoseDifferentiableAttrIndirectGenericType(
               attr->getLocation(), req.getFirstType(),
               reqRepr->getSubjectRepr()))
         return false;

-      if (!req.getSecondType()->is<ProtocolType>()) {
+      if (!req.getSecondType()->is<ProtocolType>() &&
+          !req.getSecondType()->is<ProtocolCompositionType>()) {
        TC.diagnose(attr->getLocation(),
                    diag::differentiable_attr_non_protocol_type_constraint_req)
            .highlight(reqRepr->getSourceRange());
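
The check above previously accepted only a plain `ProtocolType` as the constraint in a `@differentiable` where-clause conformance requirement; it now also accepts a `ProtocolCompositionType` such as `Differentiable & FloatingPoint`. The sketch below shows such a composition on the right-hand side of an ordinary generic where clause (where it has always been legal); `identityLike` is a hypothetical helper, assuming the swift-for-tensorflow toolchain.

// `Differentiable & FloatingPoint` is a protocol composition type; the
// attribute checker now permits it as a conformance constraint as well.
func identityLike<T>(_ x: T) -> T where T : Differentiable & FloatingPoint {
  return x
}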

stdlib/public/TensorFlow/CompositeMath.swift

Lines changed: 10 additions & 4 deletions
@@ -19,22 +19,28 @@
 /// Computes `sigmoid` of the specified tensor element-wise.
 /// Specifically, computes `1 / (1 + exp(-x))`.
 @inlinable @inline(__always)
-public func sigmoid<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func sigmoid<T>(_ x: Tensor<T>) -> Tensor<T>
+  where T : Differentiable & FloatingPoint
+{
   return 1 / (1 + exp(-x))
 }

 /// Computes `relu` of the specified tensor element-wise.
 /// Specifically, computes `max(0, x)`.
 @inlinable @inline(__always)
 @differentiable(adjoint: _adjointRelu(_:_:_:))
-public func relu<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func relu<T>(_ x: Tensor<T>) -> Tensor<T>
+  where T : Differentiable & FloatingPoint
+{
   return max(0, x)
 }

 /// Computes the softmax of the specified tensor element-wise.
 /// Specifically, computes `exp(x) / exp(x).sum()`.
 @inlinable @inline(__always)
-public func softmax<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func softmax<T>(_ x: Tensor<T>) -> Tensor<T>
+  where T : Differentiable & FloatingPoint
+{
   let expx = exp(x)
   let sum = expx.sum()
   return expx / sum
@@ -43,7 +49,7 @@ public func softmax<T : BinaryFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
 /// Computes the softmax of the specified tensor along the specified axis.
 /// Specifically, computes `exp(x) / exp(x).sum(alongAxes: axis)`.
 @inlinable @inline(__always)
-public func softmax<T : BinaryFloatingPoint>(
+public func softmax<T : Differentiable & FloatingPoint>(
   _ x: Tensor<T>, alongAxis axis: Int32
 ) -> Tensor<T> {
   let expx = exp(x)
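
With the constraints above, these composite math functions now require a scalar that is both `Differentiable` and `FloatingPoint`. A hedged usage sketch, assuming the TensorFlow module from this branch and array-literal `Tensor` initialization:

import TensorFlow

// `Float` satisfies `Differentiable & FloatingPoint`, so it matches the
// new generic constraints on `sigmoid`, `relu`, and `softmax`.
let x: Tensor<Float> = [-1, 0, 2]   // assumes ExpressibleByArrayLiteral support
let activated = relu(x)             // expected: [0, 0, 2]
let probabilities = softmax(x)      // elements sum to 1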

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 24 additions & 30 deletions
@@ -36,9 +36,6 @@
 // TODO:
 // - Add gradients for more ops ('sum', 'mean', etc).
 // - Fix gradients for broadcasting ops (need to perform reduction).
-// - When the trailing 'where' clause in @differentiable is properly
-//   type-checked, define constraints on BinaryFloatingPoint in original
-//   declarations and define adjoints on BinaryFloatingPoint.
 //
 // FIXME:
 // - Handle scalar broadcasting.
@@ -49,7 +46,7 @@
 // Elementwise binary
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : Numeric {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointAdd(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor, _ y: Tensor
@@ -84,7 +81,7 @@ extension Tensor where Scalar : Numeric {
 }

 @inlinable
-func _adjointMinMax<T : Numeric & Comparable>(
+func _adjointMinMax<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   let denom = 1 + Tensor<T>(x .== y)
@@ -94,7 +91,7 @@ func _adjointMinMax<T : Numeric & Comparable>(
 }

 @inlinable
-func _adjointPow<T : BinaryFloatingPoint>(
+func _adjointPow<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   return ((seed * y * pow(x, y-1)).unbroadcast(like: x),
@@ -105,7 +102,7 @@ func _adjointPow<T : BinaryFloatingPoint>(
 // Elementwise unary
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : SignedNumeric {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointNegate(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor
@@ -115,90 +112,90 @@ extension Tensor where Scalar : SignedNumeric {
 }

 @inlinable
-func _adjointLog<T : BinaryFloatingPoint>(
+func _adjointLog<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / x
 }

 @inlinable
-func _adjointSin<T : BinaryFloatingPoint>(
+func _adjointSin<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cos(x)
 }

 @inlinable
-func _adjointCos<T : BinaryFloatingPoint>(
+func _adjointCos<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed * sin(x)
 }

 @inlinable
-func _adjointTan<T : BinaryFloatingPoint>(
+func _adjointTan<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 + originalValue.squared())
 }

 @inlinable
-func _adjointSinh<T : BinaryFloatingPoint>(
+func _adjointSinh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cosh(x)
 }

 @inlinable
-func _adjointCosh<T : BinaryFloatingPoint>(
+func _adjointCosh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * sinh(x)
 }

 @inlinable
-func _adjointTanh<T : BinaryFloatingPoint>(
+func _adjointTanh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 - originalValue.squared())
 }

 @inlinable
-func _adjointExp<T : BinaryFloatingPoint>(
+func _adjointExp<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return originalValue * seed
 }

 @inlinable
-func _adjointCeil<T : BinaryFloatingPoint>(
+func _adjointCeil<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }

 @inlinable
-func _adjointFloor<T : BinaryFloatingPoint>(
+func _adjointFloor<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }

 @inlinable
-func _adjointSqrt<T : BinaryFloatingPoint>(
+func _adjointSqrt<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / (2 * originalValue)
 }

 @inlinable
-func _adjointRsqrt<T : BinaryFloatingPoint>(
+func _adjointRsqrt<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed / 2 * pow(originalValue, 3)
 }

-func _adjointSquared<T : BinaryFloatingPoint>(
+func _adjointSquared<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return 2 * x * seed
@@ -209,7 +206,7 @@ func _adjointSquared<T : BinaryFloatingPoint>(
 //===----------------------------------------------------------------------===//

 @inlinable
-func _adjointMatmul<Scalar : Numeric>(
+func _adjointMatmul<Scalar : Differentiable & FloatingPoint>(
   _ seed: Tensor<Scalar>, _ originalValue: Tensor<Scalar>,
   _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
 ) -> (Tensor<Scalar>, Tensor<Scalar>) {
@@ -220,16 +217,14 @@ func _adjointMatmul<Scalar : Numeric>(
 // TODO: We have to define a custom adjoint on • because AD can't yet
 // differentiate generic methods. After AD can differentiate generic methods,
 // remove the custom adjoint.
-extension Tensor where Scalar : Numeric {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointMatmulOperator(seed: Tensor, originalValue: Tensor,
                                      lhs: Tensor, rhs: Tensor)
     -> (Tensor, Tensor) {
     return _adjointMatmul(seed, originalValue, lhs, rhs)
   }
-}

-extension Tensor {
   @inlinable
   func _adjointTransposed(
     _ seed: Tensor, _ originalValue: Tensor, _ permutations: Tensor<Int32>
@@ -243,7 +238,7 @@ extension Tensor {
 // Shape transformations
 //===----------------------------------------------------------------------===//

-extension Tensor {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   func _adjointReshaped(
     seed: Tensor, originalValue: Tensor, toShape newShape: Tensor<Int32>
@@ -265,9 +260,8 @@ extension Tensor {
 // Normalization
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : BinaryFloatingPoint,
-                        Scalar : Differentiable,
-                        Scalar.CotangentVector == Scalar {
+extension Tensor where Scalar : BinaryFloatingPoint & Differentiable,
+                        Scalar == Scalar.CotangentVector {
   // TODO: Verify that these calculations are correct.
   @inlinable
   func _adjointBatchNormalized(
@@ -304,7 +298,7 @@ extension Tensor where Scalar : BinaryFloatingPoint,
 // Convolution and pooling
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : BinaryFloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   /// TensorFlow builtin conv2d gradient helper for the input.
   @inlinable
   @differentiable(
@@ -448,7 +442,7 @@ extension Tensor where Scalar : BinaryFloatingPoint {
 //===----------------------------------------------------------------------===//

 @inlinable
-func _adjointRelu<T : BinaryFloatingPoint>(
+func _adjointRelu<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(x .> 0) * seed
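
Every `_adjoint*` helper in this file follows the same convention: it takes the upstream seed (the adjoint of the result), the original value, and the original arguments, and returns the adjoint(s) of the inputs. A minimal scalar analogue of `_adjointRelu` above, for illustration only:

// Scalar sketch of the (seed, originalValue, x) -> input-adjoint convention.
// Mirrors the tensor version `Tensor(x .> 0) * seed`: the derivative of
// max(0, x) is 1 where x > 0 and 0 elsewhere.
func scalarAdjointRelu(_ seed: Float, _ originalValue: Float, _ x: Float) -> Float {
  return x > 0 ? seed : 0
}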
