Skip to content

Commit 2efb6e1

Browse files
authored
Merge pull request #24896 from rxwei/infer-diffable
[AutoDiff] [Sema] Infer Differentiable conformance for @differentiable function parameters and results.
2 parents ebf21f8 + 3cf23ce commit 2efb6e1

File tree

3 files changed

+71
-50
lines changed

3 files changed

+71
-50
lines changed

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5311,6 +5311,23 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
53115311
return Action::Continue;
53125312
}
53135313

5314+
// SWIFT_ENABLE_TENSORFLOW
5315+
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
5316+
if (fnTy->getExtInfo().isDifferentiable()) {
5317+
auto *diffableProto = Builder.getASTContext()
5318+
.getProtocol(KnownProtocolKind::Differentiable);
5319+
auto constrainToDifferentiable = [&](Type typeToConstrain) {
5320+
Requirement req(RequirementKind::Conformance, typeToConstrain,
5321+
diffableProto->getDeclaredType());
5322+
Builder.addRequirement(req, source, nullptr);
5323+
};
5324+
for (auto &param : fnTy->getParams())
5325+
if (!param.isNonDifferentiable())
5326+
constrainToDifferentiable(param.getPlainType());
5327+
constrainToDifferentiable(fnTy->getResult());
5328+
}
5329+
}
5330+
53145331
if (!ty->isSpecialized())
53155332
return Action::Continue;
53165333

stdlib/public/core/AutoDiff.swift

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -213,29 +213,29 @@ public extension Differentiable {
213213

214214
public extension Differentiable {
215215
@inlinable
216-
func valueWithPullback<R : Differentiable>(
216+
func valueWithPullback<R>(
217217
in f: @differentiable (Self) -> R
218218
) -> (value: R, pullback: (R.TangentVector) -> TangentVector) {
219219
return Builtin.autodiffApply_vjp_arity1(f, self)
220220
}
221221

222222
@inlinable
223-
func pullback<R : Differentiable>(
223+
func pullback<R>(
224224
in f: @differentiable (Self) -> R
225225
) -> (R.TangentVector) -> TangentVector {
226226
return Builtin.autodiffApply_vjp_arity1(f, self).1
227227
}
228228

229229
@inlinable
230-
func gradient<R : Differentiable>(
230+
func gradient<R>(
231231
in f: @differentiable (Self) -> R
232232
) -> TangentVector
233233
where R : FloatingPoint, R.TangentVector == R {
234234
return self.pullback(in: f)(R(1))
235235
}
236236

237237
@inlinable
238-
func valueWithGradient<R : Differentiable>(
238+
func valueWithGradient<R>(
239239
in f: @differentiable (Self) -> R
240240
) -> (value: R, gradient: TangentVector)
241241
where R : FloatingPoint, R.TangentVector == R {
@@ -244,30 +244,30 @@ public extension Differentiable {
244244
}
245245

246246
@inlinable
247-
func valueWithPullback<T : Differentiable, R : Differentiable>(
247+
func valueWithPullback<T, R>(
248248
at x: T, in f: @differentiable (Self, T) -> R
249249
) -> (value: R,
250250
pullback: (R.TangentVector) -> (TangentVector, T.TangentVector)) {
251251
return Builtin.autodiffApply_vjp_arity2(f, self, x)
252252
}
253253

254254
@inlinable
255-
func pullback<T : Differentiable, R : Differentiable>(
255+
func pullback<T, R>(
256256
at x: T, in f: @differentiable (Self, T) -> R
257257
) -> (R.TangentVector) -> (TangentVector, T.TangentVector) {
258258
return Builtin.autodiffApply_vjp_arity2(f, self, x).1
259259
}
260260

261261
@inlinable
262-
func gradient<T : Differentiable, R : Differentiable>(
262+
func gradient<T, R>(
263263
at x: T, in f: @differentiable (Self, T) -> R
264264
) -> (TangentVector, T.TangentVector)
265265
where R : FloatingPoint, R.TangentVector == R {
266266
return self.pullback(at: x, in: f)(R(1))
267267
}
268268

269269
@inlinable
270-
func valueWithGradient<T : Differentiable, R : Differentiable>(
270+
func valueWithGradient<T, R>(
271271
at x: T, in f: @differentiable (Self, T) -> R
272272
) -> (value: R, gradient: (TangentVector, T.TangentVector))
273273
where R : FloatingPoint, R.TangentVector == R {
@@ -285,17 +285,15 @@ public extension Differentiable {
285285
@inlinable
286286
public func valueWithPullback<T, R>(
287287
at x: T, in f: @differentiable (T) -> R
288-
) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
289-
where T : Differentiable, R : Differentiable {
288+
) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) {
290289
return Builtin.autodiffApply_vjp(f, x)
291290
}
292291

293292
@inlinable
294293
public func valueWithPullback<T, U, R>(
295294
at x: T, _ y: U, in f: @differentiable (T, U) -> R
296295
) -> (value: R,
297-
pullback: (R.TangentVector) -> (T.TangentVector, U.TangentVector))
298-
where T : Differentiable, U : Differentiable, R : Differentiable {
296+
pullback: (R.TangentVector) -> (T.TangentVector, U.TangentVector)) {
299297
return Builtin.autodiffApply_vjp_arity2(f, x, y)
300298
}
301299

@@ -304,9 +302,7 @@ public func valueWithPullback<T, U, V, R>(
304302
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
305303
) -> (value: R,
306304
pullback: (R.TangentVector)
307-
-> (T.TangentVector, U.TangentVector, V.TangentVector))
308-
where T : Differentiable, U : Differentiable, V : Differentiable,
309-
R : Differentiable {
305+
-> (T.TangentVector, U.TangentVector, V.TangentVector)) {
310306
return Builtin.autodiffApply_vjp_arity3(f, x, y, z)
311307
}
312308

@@ -315,26 +311,22 @@ public func valueWithPullback<T, U, V, R>(
315311
@inlinable
316312
public func pullback<T, R>(
317313
at x: T, in f: @differentiable (T) -> R
318-
) -> (R.TangentVector) -> T.TangentVector
319-
where T : Differentiable, R : Differentiable {
314+
) -> (R.TangentVector) -> T.TangentVector {
320315
return Builtin.autodiffApply_vjp(f, x).1
321316
}
322317

323318
@inlinable
324319
public func pullback<T, U, R>(
325320
at x: T, _ y: U, in f: @differentiable (T, U) -> R
326-
) -> (R.TangentVector) -> (T.TangentVector, U.TangentVector)
327-
where T : Differentiable, U : Differentiable, R : Differentiable {
321+
) -> (R.TangentVector) -> (T.TangentVector, U.TangentVector) {
328322
return Builtin.autodiffApply_vjp_arity2(f, x, y).1
329323
}
330324

331325
@inlinable
332326
public func pullback<T, U, V, R>(
333327
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
334328
) -> (R.TangentVector)
335-
-> (T.TangentVector, U.TangentVector, V.TangentVector)
336-
where T : Differentiable, U : Differentiable, V : Differentiable,
337-
R : Differentiable {
329+
-> (T.TangentVector, U.TangentVector, V.TangentVector) {
338330
return Builtin.autodiffApply_vjp_arity3(f, x, y, z).1
339331
}
340332

@@ -344,8 +336,7 @@ public func pullback<T, U, V, R>(
344336
public func valueWithGradient<T, R>(
345337
at x: T, in f: @differentiable (T) -> R
346338
) -> (value: R, gradient: T.TangentVector)
347-
where T : Differentiable, R : FloatingPoint & Differentiable,
348-
R.TangentVector == R {
339+
where R : FloatingPoint, R.TangentVector == R {
349340
let (y, pullback) = valueWithPullback(at: x, in: f)
350341
return (y, pullback(R(1)))
351342
}
@@ -354,8 +345,7 @@ public func valueWithGradient<T, R>(
354345
public func valueWithGradient<T, U, R>(
355346
at x: T, _ y: U, in f: @differentiable (T, U) -> R
356347
) -> (value: R, gradient: (T.TangentVector, U.TangentVector))
357-
where T : Differentiable, U : Differentiable,
358-
R : FloatingPoint & Differentiable, R.TangentVector == R {
348+
where R : FloatingPoint, R.TangentVector == R {
359349
let (y, pullback) = valueWithPullback(at: x, y, in: f)
360350
return (y, pullback(R(1)))
361351
}
@@ -365,8 +355,7 @@ public func valueWithGradient<T, U, V, R>(
365355
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
366356
) -> (value: R,
367357
gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
368-
where T : Differentiable, U : Differentiable, V : Differentiable,
369-
R : FloatingPoint & Differentiable, R.TangentVector == R {
358+
where R : FloatingPoint, R.TangentVector == R {
370359
let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
371360
return (y, pullback(R(1)))
372361
}
@@ -377,18 +366,15 @@ public func valueWithGradient<T, U, V, R>(
377366
public func valueWithGradient<T, R>(
378367
of f: @escaping @differentiable (T) -> R
379368
) -> (T) -> (value: R, gradient: T.TangentVector)
380-
where T : Differentiable, R : FloatingPoint & Differentiable,
381-
R.TangentVector == R {
369+
where R : FloatingPoint, R.TangentVector == R {
382370
return { x in valueWithGradient(at: x, in: f) }
383371
}
384372

385373
@inlinable
386374
public func valueWithGradient<T, U, R>(
387375
of f: @escaping @differentiable (T, U) -> R
388376
) -> (T, U) -> (value: R, gradient: (T.TangentVector, U.TangentVector))
389-
where T : Differentiable, U : Differentiable,
390-
R : FloatingPoint & Differentiable,
391-
R.TangentVector == R {
377+
where R : FloatingPoint, R.TangentVector == R {
392378
return { x, y in valueWithGradient(at: x, y, in: f) }
393379
}
394380

@@ -398,9 +384,7 @@ public func valueWithGradient<T, U, V, R>(
398384
) -> (T, U, V)
399385
-> (value: R,
400386
gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
401-
where T : Differentiable, U : Differentiable, V : Differentiable,
402-
R : FloatingPoint & Differentiable,
403-
R.TangentVector == R {
387+
where R : FloatingPoint, R.TangentVector == R {
404388
return { x, y, z in valueWithGradient(at: x, y, z, in: f) }
405389
}
406390

@@ -410,26 +394,23 @@ public func valueWithGradient<T, U, V, R>(
410394
public func gradient<T, R>(
411395
at x: T, in f: @differentiable (T) -> R
412396
) -> T.TangentVector
413-
where T : Differentiable, R : FloatingPoint & Differentiable,
414-
R.TangentVector == R {
397+
where R : FloatingPoint, R.TangentVector == R {
415398
return pullback(at: x, in: f)(R(1))
416399
}
417400

418401
@inlinable
419402
public func gradient<T, U, R>(
420403
at x: T, _ y: U, in f: @differentiable (T, U) -> R
421404
) -> (T.TangentVector, U.TangentVector)
422-
where T : Differentiable, U : Differentiable,
423-
R : FloatingPoint & Differentiable, R.TangentVector == R {
405+
where R : FloatingPoint, R.TangentVector == R {
424406
return pullback(at: x, y, in: f)(R(1))
425407
}
426408

427409
@inlinable
428410
public func gradient<T, U, V, R>(
429411
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
430412
) -> (T.TangentVector, U.TangentVector, V.TangentVector)
431-
where T : Differentiable, U : Differentiable, V : Differentiable,
432-
R : FloatingPoint & Differentiable, R.TangentVector == R {
413+
where R : FloatingPoint, R.TangentVector == R {
433414
return pullback(at: x, y, z, in: f)(R(1))
434415
}
435416

@@ -439,28 +420,23 @@ public func gradient<T, U, V, R>(
439420
public func gradient<T, R>(
440421
of f: @escaping @differentiable (T) -> R
441422
) -> (T) -> T.TangentVector
442-
where T : Differentiable, R : FloatingPoint & Differentiable,
443-
R.TangentVector == R {
423+
where R : FloatingPoint, R.TangentVector == R {
444424
return { x in gradient(at: x, in: f) }
445425
}
446426

447427
@inlinable
448428
public func gradient<T, U, R>(
449429
of f: @escaping @differentiable (T, U) -> R
450430
) -> (T, U) -> (T.TangentVector, U.TangentVector)
451-
where T : Differentiable, U : Differentiable,
452-
R : FloatingPoint & Differentiable,
453-
R.TangentVector == R {
431+
where R : FloatingPoint, R.TangentVector == R {
454432
return { x, y in gradient(at: x, y, in: f) }
455433
}
456434

457435
@inlinable
458436
public func gradient<T, U, V, R>(
459437
of f: @escaping @differentiable (T, U, V) -> R
460438
) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector)
461-
where T : Differentiable, U : Differentiable, V : Differentiable,
462-
R : FloatingPoint & Differentiable,
463-
R.TangentVector == R {
439+
where R : FloatingPoint, R.TangentVector == R {
464440
return { x, y, z in gradient(at: x, y, z, in: f) }
465441
}
466442

test/AutoDiff/differentiable_func_type_type_checking.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,31 @@ func test2<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -
3535
func test3<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @differentiable (U) -> Int) {}
3636
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
3737
func test4<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Int) {}
38+
39+
let diffFunc: @differentiable (Float) -> Float
40+
func inferredConformances<T, U>(_: @differentiable (T) -> U) {}
41+
inferredConformances(diffFunc)
42+
43+
func inferredConformancesResult<T, U>() -> @differentiable (T) -> U {}
44+
45+
let diffFuncWithNondiff: @differentiable (Float, @nondiff Int) -> Float
46+
func inferredConformances<T, U, V>(_: @differentiable (T, @nondiff U) -> V) {}
47+
inferredConformances(diffFuncWithNondiff)
48+
49+
struct Vector<T> {
50+
var x, y: T
51+
}
52+
extension Vector: Differentiable where T: Differentiable {}
53+
54+
// expected-note @+2 {{where 'T' = 'Int'}}
55+
// expected-note @+1 {{where 'U' = 'Int'}}
56+
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}
57+
58+
let nondiffVectorFunc: (Vector<Int>) -> Vector<Int>
59+
// expected-error @+1 2 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
60+
inferredConformancesGeneric(nondiffVectorFunc)
61+
62+
let diffVectorFunc: (Vector<Float>) -> Vector<Float>
63+
inferredConformancesGeneric(diffVectorFunc) // okay!
64+
65+
func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}

0 commit comments

Comments
 (0)