
Commit 96fffb2

[AutoDiff] [API] Change Tensor conditional conformance to Differentiable.
`Tensor` now conditionally conforms to `Differentiable` where `Scalar : Differentiable & FloatingPoint`. All `@differentiable` where clauses and adjoint definitions have been updated accordingly. Conformance requirements in `@differentiable` where clauses may now also reference protocol composition types.
Parent: cba5a2b

5 files changed (+87, −55 lines)
lib/Sema/TypeCheckAttr.cpp

Lines changed: 3 additions & 2 deletions
@@ -2345,14 +2345,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {

     // Conformance requirements are valid if:
     // - The first type is a generic type parameter type.
-    // - The second type is a protocol type.
+    // - The second type is a protocol type or protocol composition type.
     case RequirementKind::Conformance:
       if (diagnoseDifferentiableAttrIndirectGenericType(
               attr->getLocation(), req.getFirstType(),
               reqRepr->getSubjectRepr()))
         return false;

-      if (!req.getSecondType()->is<ProtocolType>()) {
+      if (!req.getSecondType()->is<ProtocolType>() &&
+          !req.getSecondType()->is<ProtocolCompositionType>()) {
         TC.diagnose(attr->getLocation(),
                     diag::differentiable_attr_non_protocol_type_constraint_req)
             .highlight(reqRepr->getSourceRange());
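
With the check relaxed, the right-hand side of a conformance requirement in a `@differentiable` where clause may be a composition. A plain-Swift sketch of the two cases the type checker distinguishes (stand-in protocols `P` and `Q`; ordinary generic where clauses have always accepted both forms, so this only illustrates the AST distinction):

protocol P {}
protocol Q {}

// `T : P` is a conformance requirement whose second type is a
// ProtocolType.
func f<T>(_ x: T) where T : P {}

// `T : P & Q` is a single conformance requirement whose second type is
// a ProtocolCompositionType; visitDifferentiableAttr now accepts this
// form too.
func g<T>(_ x: T) where T : P & Q {}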

stdlib/public/TensorFlow/CompositeMath.swift

Lines changed: 10 additions & 4 deletions
@@ -19,22 +19,28 @@
 /// Computes `sigmoid` of the specified tensor element-wise.
 /// Specifically, computes `1 / (1 + exp(-x))`.
 @inlinable @inline(__always)
-public func sigmoid<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func sigmoid<T>(_ x: Tensor<T>) -> Tensor<T>
+  where T : Differentiable & FloatingPoint
+{
   return 1 / (1 + exp(-x))
 }

 /// Computes `relu` of the specified tensor element-wise.
 /// Specifically, computes `max(0, x)`.
 @inlinable @inline(__always)
 @differentiable(adjoint: _adjointRelu(_:_:_:))
-public func relu<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func relu<T>(_ x: Tensor<T>) -> Tensor<T>
+  where T : Differentiable & FloatingPoint
+{
   return max(0, x)
 }

 /// Computes the softmax of the specified tensor element-wise.
 /// Specifically, computes `exp(x) / exp(x).sum()`.
 @inlinable @inline(__always)
-public func softmax<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
+public func softmax<T>(_ x: Tensor<T>) -> Tensor<T>
+  where T : Differentiable & FloatingPoint
+{
   let expx = exp(x)
   let sum = expx.sum()
   return expx / sum
@@ -43,7 +49,7 @@ public func softmax<T : FloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
 /// Computes the softmax of the specified tensor along the specified axis.
 /// Specifically, computes `exp(x) / exp(x).sum(alongAxes: axis)`.
 @inlinable @inline(__always)
-public func softmax<T : FloatingPoint>(
+public func softmax<T : Differentiable & FloatingPoint>(
   _ x: Tensor<T>, alongAxis axis: Int32
 ) -> Tensor<T> {
   let expx = exp(x)
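
As a quick plausibility check of the `sigmoid` formula above, here is a scalar version in plain Swift (using `Double` directly rather than `Tensor`; `exp` comes from Foundation):

import Foundation

// Scalar sigmoid: 1 / (1 + exp(-x)), matching the Tensor doc comment.
func scalarSigmoid(_ x: Double) -> Double {
  return 1 / (1 + exp(-x))
}

print(scalarSigmoid(0))   // 0.5
print(scalarSigmoid(10))  // ≈ 0.9999546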

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 22 additions & 22 deletions
@@ -46,7 +46,7 @@
 // Elementwise binary
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : FloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointAdd(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor, _ y: Tensor
@@ -81,7 +81,7 @@ extension Tensor where Scalar : FloatingPoint {
 }

 @inlinable
-func _adjointMinMax<T : FloatingPoint>(
+func _adjointMinMax<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   let denom = 1 + Tensor<T>(x .== y)
@@ -91,7 +91,7 @@ func _adjointMinMax<T : FloatingPoint>(
 }

 @inlinable
-func _adjointPow<T : FloatingPoint>(
+func _adjointPow<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>, _ y: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
   return ((seed * y * pow(x, y-1)).unbroadcast(like: x),
@@ -102,7 +102,7 @@ func _adjointPow<T : FloatingPoint>(
 // Elementwise unary
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : FloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointNegate(
     _ seed: Tensor, _ originalValue: Tensor, _ x: Tensor
@@ -112,90 +112,90 @@ extension Tensor where Scalar : FloatingPoint {
 }

 @inlinable
-func _adjointLog<T : FloatingPoint>(
+func _adjointLog<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / x
 }

 @inlinable
-func _adjointSin<T : FloatingPoint>(
+func _adjointSin<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cos(x)
 }

 @inlinable
-func _adjointCos<T : FloatingPoint>(
+func _adjointCos<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed * sin(x)
 }

 @inlinable
-func _adjointTan<T : FloatingPoint>(
+func _adjointTan<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 + originalValue.squared())
 }

 @inlinable
-func _adjointSinh<T : FloatingPoint>(
+func _adjointSinh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * cosh(x)
 }

 @inlinable
-func _adjointCosh<T : FloatingPoint>(
+func _adjointCosh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * sinh(x)
 }

 @inlinable
-func _adjointTanh<T : FloatingPoint>(
+func _adjointTanh<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed * (1 - originalValue.squared())
 }

 @inlinable
-func _adjointExp<T : FloatingPoint>(
+func _adjointExp<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return originalValue * seed
 }

 @inlinable
-func _adjointCeil<T : FloatingPoint>(
+func _adjointCeil<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }

 @inlinable
-func _adjointFloor<T : FloatingPoint>(
+func _adjointFloor<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(0).broadcast(like: x)
 }

 @inlinable
-func _adjointSqrt<T : FloatingPoint>(
+func _adjointSqrt<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return seed / (2 * originalValue)
 }

 @inlinable
-func _adjointRsqrt<T : FloatingPoint>(
+func _adjointRsqrt<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return -seed / 2 * pow(originalValue, 3)
 }

-func _adjointSquared<T : FloatingPoint>(
+func _adjointSquared<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return 2 * x * seed
@@ -206,7 +206,7 @@ func _adjointSquared<T : FloatingPoint>(
 //===----------------------------------------------------------------------===//

 @inlinable
-func _adjointMatmul<Scalar : FloatingPoint>(
+func _adjointMatmul<Scalar : Differentiable & FloatingPoint>(
   _ seed: Tensor<Scalar>, _ originalValue: Tensor<Scalar>,
   _ left: Tensor<Scalar>, _ right: Tensor<Scalar>
 ) -> (Tensor<Scalar>, Tensor<Scalar>) {
@@ -217,7 +217,7 @@ func _adjointMatmul<Scalar : FloatingPoint>(
 // TODO: We have to define a custom adjoint on • because AD can't yet
 // differentiate generic methods. After AD can differentiate generic methods,
 // remove the custom adjoint.
-extension Tensor where Scalar : FloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   static func _adjointMatmulOperator(seed: Tensor, originalValue: Tensor,
                                      lhs: Tensor, rhs: Tensor)
@@ -238,7 +238,7 @@ extension Tensor where Scalar : FloatingPoint {
 // Shape transformations
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : FloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   @inlinable
   func _adjointReshaped(
     seed: Tensor, originalValue: Tensor, toShape newShape: Tensor<Int32>
@@ -298,7 +298,7 @@ extension Tensor where Scalar : BinaryFloatingPoint & Differentiable,
 // Convolution and pooling
 //===----------------------------------------------------------------------===//

-extension Tensor where Scalar : FloatingPoint {
+extension Tensor where Scalar : Differentiable & FloatingPoint {
   /// TensorFlow builtin conv2d gradient helper for the input.
   @inlinable
   @differentiable(
@@ -442,7 +442,7 @@ extension Tensor where Scalar : FloatingPoint {
 //===----------------------------------------------------------------------===//

 @inlinable
-func _adjointRelu<T : FloatingPoint>(
+func _adjointRelu<T : Differentiable & FloatingPoint>(
   _ seed: Tensor<T>, _ originalValue: Tensor<T>, _ x: Tensor<T>
 ) -> Tensor<T> {
   return Tensor(x .> 0) * seed
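
These adjoint formulas can be spot-checked numerically. Here is a plain-Swift finite-difference check of the `tanh` rule used above, `seed * (1 - tanh(x)^2)`, independent of the `Tensor` types:

import Foundation

let x = 0.3
let seed = 1.0
let eps = 1e-6

// Central finite difference approximates seed * d/dx tanh(x).
let numeric = seed * (tanh(x + eps) - tanh(x - eps)) / (2 * eps)
// Analytic adjoint, as in _adjointTanh: seed * (1 - originalValue^2),
// where originalValue = tanh(x).
let analytic = seed * (1 - tanh(x) * tanh(x))

print(abs(numeric - analytic) < 1e-8)  // expect: true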
