Skip to content

Commit 0547eb5

Browse files
njjiangrxwei
authored andcommitted
[AutoDiff] Declare differential-producing differential operators. (#25697)
Add differential operators in the [design overview](http://bit.ly/swift-autodiff) that haven't been declared and reorder the declarations to match the design overview.
1 parent ea578a8 commit 0547eb5

File tree

1 file changed

+167
-41
lines changed

1 file changed

+167
-41
lines changed

stdlib/public/core/AutoDiff.swift

Lines changed: 167 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,37 @@ public extension Differentiable {
336336
// Free-function-style differential operators
337337
//===----------------------------------------------------------------------===//
338338

339+
// Transpose
340+
341+
@available(*, unavailable)
342+
@inlinable
343+
public func transpose<T, R>(
344+
of body: @escaping @differentiable/*(linear)*/ (T) -> R
345+
) -> @differentiable/*(linear)*/ (R) -> T {
346+
fatalError()
347+
}
348+
349+
// Value with differential
350+
351+
@available(*, unavailable)
352+
@inlinable
353+
public func valueWithDifferential<T, R>(
354+
at x: T, in body: @differentiable (T) -> R
355+
) -> (value: R, differential: @differentiable/*(linear)*/
356+
(T.TangentVector) -> R.TangentVector) {
357+
fatalError()
358+
}
359+
360+
@available(*, unavailable)
361+
@inlinable
362+
public func valueWithDifferential<T, U, R>(
363+
at x: T, _ y: U, in body: @differentiable (T, U) -> R
364+
) -> (value: R,
365+
differential: @differentiable/*(linear)*/
366+
(T.TangentVector, U.TangentVector) -> R.TangentVector) {
367+
fatalError()
368+
}
369+
339370
// Value with pullback
340371

341372
@inlinable
@@ -362,6 +393,25 @@ public func valueWithPullback<T, U, V, R>(
362393
return Builtin.autodiffApply_vjp_arity3(f, x, y, z)
363394
}
364395

396+
// Differential
397+
398+
@available(*, unavailable)
399+
@inlinable
400+
public func differential<T, R>(
401+
at x: T, in body: @differentiable(T) -> R
402+
) -> @differentiable/*(linear)*/ (T.TangentVector) -> R.TangentVector {
403+
fatalError()
404+
}
405+
406+
@available(*, unavailable)
407+
@inlinable
408+
public func differential<T, U, R>(
409+
at x: T, _ y: U, in body: @differentiable(T, U) -> R
410+
) -> @differentiable/*(linear)*/ (T.TangentVector, U.TangentVector)
411+
-> R.TangentVector {
412+
fatalError()
413+
}
414+
365415
// Pullback
366416

367417
@inlinable
@@ -386,6 +436,70 @@ public func pullback<T, U, V, R>(
386436
return Builtin.autodiffApply_vjp_arity3(f, x, y, z).1
387437
}
388438

439+
// Derivative
440+
441+
@available(*, unavailable)
442+
@inlinable
443+
public func derivative<T: FloatingPoint, R>(
444+
at x: T, in body: @escaping @differentiable (T) -> R
445+
) -> R.TangentVector where T.TangentVector : FloatingPoint {
446+
fatalError()
447+
}
448+
449+
@available(*, unavailable)
450+
@inlinable
451+
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
452+
at x: T, _ y: U, in body: @escaping @differentiable (T) -> R
453+
) -> R.TangentVector where T.TangentVector : FloatingPoint {
454+
fatalError()
455+
}
456+
457+
// Gradient
458+
459+
@inlinable
460+
public func gradient<T, R>(
461+
at x: T, in f: @differentiable (T) -> R
462+
) -> T.TangentVector
463+
where R : FloatingPoint, R.TangentVector == R {
464+
return pullback(at: x, in: f)(R(1))
465+
}
466+
467+
@inlinable
468+
public func gradient<T, U, R>(
469+
at x: T, _ y: U, in f: @differentiable (T, U) -> R
470+
) -> (T.TangentVector, U.TangentVector)
471+
where R : FloatingPoint, R.TangentVector == R {
472+
return pullback(at: x, y, in: f)(R(1))
473+
}
474+
475+
@inlinable
476+
public func gradient<T, U, V, R>(
477+
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
478+
) -> (T.TangentVector, U.TangentVector, V.TangentVector)
479+
where R : FloatingPoint, R.TangentVector == R {
480+
return pullback(at: x, y, z, in: f)(R(1))
481+
}
482+
483+
// Value with derivative
484+
485+
@available(*, unavailable)
486+
@inlinable
487+
public func valueWithDerivative<T: FloatingPoint, R>(
488+
at x: T, in body: @escaping @differentiable (T) -> R
489+
) -> (value: R, derivative: R.TangentVector)
490+
where T.TangentVector : FloatingPoint {
491+
fatalError()
492+
}
493+
494+
@available(*, unavailable)
495+
@inlinable
496+
public func valueWithDerivative<T: FloatingPoint, U: FloatingPoint, R>(
497+
at x: T, _ y: U, in body: @escaping @differentiable (T) -> R
498+
) -> (value: R, derivative: R.TangentVector)
499+
where T.TangentVector : FloatingPoint {
500+
fatalError()
501+
}
502+
389503
// Value with gradient
390504

391505
@inlinable
@@ -416,84 +530,96 @@ public func valueWithGradient<T, U, V, R>(
416530
return (y, pullback(R(1)))
417531
}
418532

419-
// Value with gradient (curried)
533+
// Derivative (curried)
534+
535+
@available(*, unavailable)
536+
@inlinable
537+
public func derivative<T: FloatingPoint, R>(
538+
of body: @escaping @differentiable (T) -> R
539+
) -> (T) -> R.TangentVector where T.TangentVector : FloatingPoint {
540+
fatalError()
541+
}
542+
543+
@available(*, unavailable)
544+
@inlinable
545+
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
546+
of body: @escaping @differentiable (T, U) -> R
547+
) -> (T, U) -> R.TangentVector where T.TangentVector : FloatingPoint {
548+
fatalError()
549+
}
550+
551+
// Gradient (curried)
420552

421553
@inlinable
422-
public func valueWithGradient<T, R>(
554+
public func gradient<T, R>(
423555
of f: @escaping @differentiable (T) -> R
424-
) -> (T) -> (value: R, gradient: T.TangentVector)
556+
) -> (T) -> T.TangentVector
425557
where R : FloatingPoint, R.TangentVector == R {
426-
return { x in valueWithGradient(at: x, in: f) }
558+
return { x in gradient(at: x, in: f) }
427559
}
428560

429561
@inlinable
430-
public func valueWithGradient<T, U, R>(
562+
public func gradient<T, U, R>(
431563
of f: @escaping @differentiable (T, U) -> R
432-
) -> (T, U) -> (value: R, gradient: (T.TangentVector, U.TangentVector))
564+
) -> (T, U) -> (T.TangentVector, U.TangentVector)
433565
where R : FloatingPoint, R.TangentVector == R {
434-
return { x, y in valueWithGradient(at: x, y, in: f) }
566+
return { x, y in gradient(at: x, y, in: f) }
435567
}
436568

437569
@inlinable
438-
public func valueWithGradient<T, U, V, R>(
570+
public func gradient<T, U, V, R>(
439571
of f: @escaping @differentiable (T, U, V) -> R
440-
) -> (T, U, V)
441-
-> (value: R,
442-
gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
572+
) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector)
443573
where R : FloatingPoint, R.TangentVector == R {
444-
return { x, y, z in valueWithGradient(at: x, y, z, in: f) }
574+
return { x, y, z in gradient(at: x, y, z, in: f) }
445575
}
446576

447-
// Gradient
448-
449-
@inlinable
450-
public func gradient<T, R>(
451-
at x: T, in f: @differentiable (T) -> R
452-
) -> T.TangentVector
453-
where R : FloatingPoint, R.TangentVector == R {
454-
return pullback(at: x, in: f)(R(1))
455-
}
577+
// Value with derivative (curried)
456578

579+
@available(*, unavailable)
457580
@inlinable
458-
public func gradient<T, U, R>(
459-
at x: T, _ y: U, in f: @differentiable (T, U) -> R
460-
) -> (T.TangentVector, U.TangentVector)
461-
where R : FloatingPoint, R.TangentVector == R {
462-
return pullback(at: x, y, in: f)(R(1))
581+
public func valueWithDerivative<T: FloatingPoint, R>(
582+
of body: @escaping @differentiable (T) -> R
583+
) -> (value: R, derivative: R.TangentVector)
584+
where T.TangentVector: FloatingPoint {
585+
fatalError()
463586
}
464587

588+
@available(*, unavailable)
465589
@inlinable
466-
public func gradient<T, U, V, R>(
467-
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
468-
) -> (T.TangentVector, U.TangentVector, V.TangentVector)
469-
where R : FloatingPoint, R.TangentVector == R {
470-
return pullback(at: x, y, z, in: f)(R(1))
590+
public func valueWithDerivative<T: FloatingPoint, U: FloatingPoint, R>(
591+
of body: @escaping @differentiable (T, U) -> R
592+
) -> (value: R, derivative: R.TangentVector)
593+
where T.TangentVector: FloatingPoint, U.TangentVector: FloatingPoint {
594+
fatalError()
471595
}
472596

473-
// Gradient (curried)
597+
// Value with gradient (curried)
474598

475599
@inlinable
476-
public func gradient<T, R>(
600+
public func valueWithGradient<T, R>(
477601
of f: @escaping @differentiable (T) -> R
478-
) -> (T) -> T.TangentVector
602+
) -> (T) -> (value: R, gradient: T.TangentVector)
479603
where R : FloatingPoint, R.TangentVector == R {
480-
return { x in gradient(at: x, in: f) }
604+
return { x in valueWithGradient(at: x, in: f) }
481605
}
482606

483607
@inlinable
484-
public func gradient<T, U, R>(
608+
public func valueWithGradient<T, U, R>(
485609
of f: @escaping @differentiable (T, U) -> R
486-
) -> (T, U) -> (T.TangentVector, U.TangentVector)
610+
) -> (T, U) -> (value: R, gradient: (T.TangentVector, U.TangentVector))
487611
where R : FloatingPoint, R.TangentVector == R {
488-
return { x, y in gradient(at: x, y, in: f) }
612+
return { x, y in valueWithGradient(at: x, y, in: f) }
489613
}
490614

491615
@inlinable
492-
public func gradient<T, U, V, R>(
616+
public func valueWithGradient<T, U, V, R>(
493617
of f: @escaping @differentiable (T, U, V) -> R
494-
) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector)
618+
) -> (T, U, V)
619+
-> (value: R,
620+
gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
495621
where R : FloatingPoint, R.TangentVector == R {
496-
return { x, y, z in gradient(at: x, y, z, in: f) }
622+
return { x, y, z in valueWithGradient(at: x, y, z, in: f) }
497623
}
498624

499625
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)