@@ -72,6 +72,7 @@ Backticks were added manually.
72
72
* [ Upcasting to non-` @differentiable ` functions] ( #upcasting-to-non-differentiable-functions )
73
73
* [ Implied generic constraints] ( #implied-generic-constraints )
74
74
* [ Non-differentiable parameters] ( #non-differentiable-parameters )
75
+ * [ Higher-order functions and currying] ( #higher-order-functions-and-currying )
75
76
* [ Differential operators] ( #differential-operators )
76
77
* [ Differential-producing differential operators] ( #differential-producing-differential-operators )
77
78
* [ Pullback-producing differential operators] ( #pullback-producing-differential-operators )
@@ -88,7 +89,6 @@ Backticks were added manually.
88
89
* [ Convolutional neural networks (CNN)] ( #convolutional-neural-networks-cnn )
89
90
* [ Recurrent neural networks (RNN)] ( #recurrent-neural-networks-rnn )
90
91
* [ Future directions] ( #future-directions )
91
- * [ Differentiation of higher-order functions] ( #differentiation-of-higher-order-functions )
92
92
* [ Higher-order differentiation] ( #higher-order-differentiation )
93
93
* [ Naming conventions for numerical computing] ( #naming-conventions-for-numerical-computing )
94
94
* [ Source compatibility] ( #source-compatibility )
@@ -1452,8 +1452,11 @@ making the other function linear.
1452
1452
1453
1453
A protocol requirement or class method/property/subscript can be made
1454
1454
differentiable via a derivative function or transpose function defined in an
1455
- extension. A dispatched call to such a member can be differentiated even if the
1456
- concrete implementation is not differentiable.
1455
+ extension. When a protocol requirement is not marked with ` @differentiable ` but
1456
+ has been made differentiable by a ` @derivative ` or ` @transpose ` declaration in a
1457
+ protocol extension, a dispatched call to such a member can be differentiated,
1458
+ and the derivative or transpose is always the one provided in the protocol
1459
+ extension.
1457
1460
1458
1461
#### Linear maps
1459
1462
@@ -1696,48 +1699,49 @@ public protocol ElementaryFunctions {
1696
1699
...
1697
1700
}
1698
1701
1699
- public extension ElementaryFunctions where Self : Differentiable , Self == Self .TangentVector {
1702
+ public extension ElementaryFunctions
1703
+ where Self : Differentiable & FloatingPoint , Self == Self .TangentVector {
1700
1704
@inlinable
1701
1705
@derivative (of: sqrt)
1702
- func _ (_ x : Self ) -> (value: Self , differential: @differential (linear) (Self ) -> Self ) {
1706
+ static func _ (_ x : Self ) -> (value: Self , differential: @differentiable (linear) (Self ) -> Self ) {
1703
1707
(sqrt (x), { dx in (1 / 2 ) * (1 / sqrt (x)) * dx })
1704
1708
}
1705
1709
1706
1710
@inlinable
1707
1711
@derivative (of: cos)
1708
- func _ (_ x : Self ) -> (value: Self , differential: @differential (linear) (Self ) -> Self ) {
1712
+ static func _ (_ x : Self ) -> (value: Self , differential: @differentiable (linear) (Self ) -> Self ) {
1709
1713
(cos (x), { dx in - sin (x) * dx })
1710
1714
}
1711
1715
1712
1716
@inlinable
1713
1717
@derivative (of: asinh)
1714
- func _ (_ x : Self ) -> (value: Self , differential: @differential (linear) (Self ) -> Self ) {
1718
+ static func _ (_ x : Self ) -> (value: Self , differential: @differentiable (linear) (Self ) -> Self ) {
1715
1719
(asinh (x), { dx in 1 / (1 + x * x) * dx })
1716
1720
}
1717
1721
1718
1722
@inlinable
1719
1723
@derivative (of: exp)
1720
- func _ (_ x : Self ) -> (value: Self , differential: @differential (linear) (Self ) -> Self ) {
1724
+ static func _ (_ x : Self ) -> (value: Self , differential: @differentiable (linear) (Self ) -> Self ) {
1721
1725
let ret = exp (x)
1722
1726
return (ret, { dx in ret * dx })
1723
1727
}
1724
1728
1725
1729
@inlinable
1726
1730
@derivative (of: exp10)
1727
- func _ (_ x : Self ) -> (value: Self , differential: @differential (linear) (Self ) -> Self ) {
1731
+ static func _ (_ x : Self ) -> (value: Self , differential: @differentiable (linear) (Self ) -> Self ) {
1728
1732
let ret = exp10 (x)
1729
1733
return (ret, { dx in exp (10 ) * ret * dx })
1730
1734
}
1731
1735
1732
1736
@inlinable
1733
1737
@derivative (of: log)
1734
- func _ (_ x : Self ) -> (value: Self , differential: @differential (linear) (Self ) -> Self ) { dx in
1735
- (log (x), { 1 / x * dx })
1738
+ static func _ (_ x : Self ) -> (value: Self , differential: @differentiable (linear) (Self ) -> Self ) {
1739
+ (log (x), { dx in 1 / x * dx })
1736
1740
}
1737
1741
1738
1742
@inlinable
1739
1743
@derivative (of: pow)
1740
- func _ (_ x : Self , _ y : Self ) -> (value: Self , differential: @differential (linear) (Self , Self ) -> Self ) {
1744
+ static func _ (_ x : Self , _ y : Self ) -> (value: Self , differential: @differentiable (linear) (Self , Self ) -> Self ) {
1741
1745
(pow (x, y), { (dx, dy) in
1742
1746
let l = y * pow (x, y- 1 ) * dx
1743
1747
let r = pow (x, y) * log (x) * dy
@@ -1749,6 +1753,73 @@ public extension ElementaryFunctions where Self: Differentiable, Self == Self.Ta
1749
1753
}
1750
1754
```
1751
1755
1756
+ #### Default derivatives
1757
+
1758
+ In a protocol extension, class definition, or class extension, providing a
1759
+ derivative or transpose for a protocol extension or a non-final class member is
1760
+ considered as providing a default derivative for that member. Types that conform
1761
+ to the protocol or inherit from the class can inherit the default derivative.
1762
+
1763
+ If the original member does not have a ` @differentiable ` attribute, a default
1764
+ derivative is implicitly added to all conforming/overriding implementations.
1765
+
1766
+ ``` swift
1767
+ protocol P {
1768
+ func foo (_ x : Float ) -> Float
1769
+ }
1770
+
1771
+ extension P {
1772
+ @derivative (of: foo (x: ))
1773
+ func _ (_ x : Float ) -> (value: Float , differential: (Float ) -> Float ) {
1774
+ (value : foo (x), differential : { _ in 42 })
1775
+ }
1776
+ }
1777
+
1778
+ struct S : P {
1779
+ func foo (_ x : Float ) -> Float {
1780
+ 33
1781
+ }
1782
+ }
1783
+
1784
+ let s = S ()
1785
+ let d = derivative (at : 0 ) { x in
1786
+ s.foo (x)
1787
+ } // ==> 42
1788
+ ```
1789
+
1790
+ When a protocol requirement or class member is marked with ` @differentiable ` , it
1791
+ is considered as a _ differentiability customization point_ . This means that all
1792
+ conforming/overriding implementation must provide a corresponding
1793
+ ` @differentiable ` attribute, which causes the implementation to be
1794
+ differentiated. To inherit the default derivative without differentiating the
1795
+ implementation, add ` default ` to the ` @differentiable ` attribute.
1796
+
1797
+ ``` swift
1798
+ protocol P {
1799
+ @differentiable
1800
+ func foo (_ x : Float ) -> Float
1801
+ }
1802
+
1803
+ extension P {
1804
+ @derivative (of: foo (x: ))
1805
+ func _ (_ x : Float ) -> (value: Float , differential: (Float ) -> Float ) {
1806
+ (value : foo (x), differential : { _ in 42 })
1807
+ }
1808
+ }
1809
+
1810
+ struct S : P {
1811
+ @differentiable (default ) // Inherits default derivative for `P.foo(_:)`.
1812
+ func foo (_ x : Float ) -> Float {
1813
+ 33
1814
+ }
1815
+ }
1816
+
1817
+ let s = S ()
1818
+ let d = derivative (at : 0 ) { x in
1819
+ s.foo (x)
1820
+ } // ==> 42
1821
+ ```
1822
+
1752
1823
### Differentiable function types
1753
1824
1754
1825
Differentiability is a fundamental mathematical concept that applies not only to
@@ -2002,6 +2073,43 @@ _ = f0 as @differentiable (@noDerivative Float, Float) -> Float
2002
2073
_ = f0 as @differentiable (@noDerivative Float , @noDerivative Float ) -> Float
2003
2074
```
2004
2075
2076
+ #### Higher-order functions and currying
2077
+
2078
+ As defined above, the ` @differentiable ` function type attributes requires all
2079
+ non-` @noDerivative ` arguments and results to conform to the ` @differentiable `
2080
+ attribute. However, there is one exception: when the type of an argument or
2081
+ result is a function type, e.g. `@differentiable (T) -> @differentiable (U) ->
2082
+ V`. This is because we need to differentiate higher-order funtions.
2083
+
2084
+ Mathematically, the differentiability of ` @differentiable (T, U) -> V ` is
2085
+ similar to that of ` @differentiable (T) -> @differentiable (U) -> V ` in that
2086
+ differentiating either one will provide derivatives with respect to parameters
2087
+ ` T ` and ` U ` . Here are some examples of first-order function types and their
2088
+ corresponding curried function types:
2089
+
2090
+ | First-order function type | Curried function type |
2091
+ | ---------------------------------------------| ---------------------------------------------------|
2092
+ | ` @differentiable (T, U) -> V ` | ` @differentiable (T) -> @differentiable (U) -> V ` |
2093
+ | ` @differentiable (T, @noDerivative U) -> V ` | ` @differentiable (T) -> (U) -> V ` |
2094
+ | ` @differentiable (@noDerivative T, U) -> V ` | ` (T) -> @differentiable (U) -> V ` |
2095
+
2096
+ A curried differentiable function can be formed like any curried
2097
+ non-differentiable function in Swift.
2098
+
2099
+ ``` swift
2100
+ func curry <T , U , V >(
2101
+ _ f : @differentiable (T, U) -> V
2102
+ ) -> @differentiable (T) -> @differentiable (U) -> V {
2103
+ { x in { y in f (x, y) } }
2104
+ }
2105
+ ```
2106
+
2107
+ The way this works is that the compiler internally assigns a tangent bundle to a
2108
+ closure that captures variables. This tangent bundle is existentially typed,
2109
+ because closure contexts are type-erased in Swift. The theory behind the typing
2110
+ rules has been published as [ The Differentiable
2111
+ Curry] ( https://www.semanticscholar.org/paper/The-Differentiable-Curry-Plotkin-Brain/187078bfb159c78cc8c78c3bbe81a9176b3a6e02 ) .
2112
+
2005
2113
### Differential operators
2006
2114
2007
2115
The core differentiation APIs are the differential operators. Differential
@@ -2021,7 +2129,7 @@ func valueWithDifferential<T, R>(
2021
2129
) -> (value: R,
2022
2130
differential: @differentiable (linear) (T.TangentVector) -> R.TangentVector) {
2023
2131
// Compiler built-in.
2024
- Builtin.autodiffApply_jvp_arity1 (body, x)
2132
+ Builtin.applyDerivative_arity1 (body, x)
2025
2133
}
2026
2134
2027
2135
@@ -2030,7 +2138,7 @@ func transpose<T, R>(
2030
2138
of body : @escaping @differentiable (linear) (T) -> R
2031
2139
) -> @differentiable (linear) (R) -> T {
2032
2140
// Compiler built-in.
2033
- { x in Builtin.autodiffApply_transpose (body, x) }
2141
+ { x in Builtin.applyTranspose_arity1 (body, x) }
2034
2142
}
2035
2143
```
2036
2144
@@ -2203,13 +2311,13 @@ whether the derivative is always zero and warns the user.
2203
2311
2204
2312
```swift
2205
2313
let grad = gradient (at : 1.0 ) { x in
2206
- 3 .squareRoot ()
2314
+ Double ( 3 ) .squareRoot ()
2207
2315
}
2208
2316
```
2209
2317
2210
2318
```console
2211
- test.swift : 4 : 18 : warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '. withoutDerivative ()' to make it explicit?
2212
- 3 .squareRoot ()
2319
+ test.swift : 4 : 18 : warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to use ' withoutDerivative (at: )' to make it explicit?
2320
+ Double ( 3 ) .squareRoot ()
2213
2321
^
2214
2322
withoutDerivative (at: )
2215
2323
```
@@ -2456,30 +2564,6 @@ typealias LSTM<Scalar: TensorFlowFloatingPoint> = RNN<LSTMCell<Scalar>>
2456
2564
2457
2565
## Future directions
2458
2566
2459
- ### Differentiation of higher- order functions
2460
-
2461
- Mathematically, the differentiability of `@differentiable (T, U) -> V` is
2462
- similar to that of `@differentiable (T) -> @differentiable (U) -> V` in that
2463
- differentiating either one will provide derivatives with respect to parameters
2464
- `T` and `U`.
2465
-
2466
- To form a `@differentiable (T) -> @differentiable (U) -> V`, the most natural
2467
- thing to do is currying, which one might implement as :
2468
-
2469
- ```swift
2470
- func curry< T, U, V> (
2471
- _ f: @differentiable (T, U) -> V
2472
- ) -> @differentiable (T) -> @differentiable (U) -> V {
2473
- { x in { y in f (x, y) } }
2474
- }
2475
- ```
2476
-
2477
- However, the compiler does not support currying today due to known
2478
- type- theoretical constraints and implementation complexity regarding
2479
- differentiating a closure with respect to the values it captures. Fortunately,
2480
- we have a formally proven solution in the works, but we would like to defer this
2481
- to a future proposal since it is purely additive to the existing semantics.
2482
-
2483
2567
### Higher- order differentiation
2484
2568
2485
2569
Distinct from differentiation of higher- order functions, higher- order
@@ -2528,9 +2612,7 @@ func valueWithDifferential<T: FloatingPoint, U: Differentiable>(
2528
2612
2529
2613
To differentiate `valueWithDifferential`, we need to be able to differentiate
2530
2614
its return value, a tuple of the original value and the differential, with
2531
- respect to its `x` argument. Since the return type contains a function,
2532
- [differentiation of higher- order functions](#differentiation - of- higher- order- functions)
2533
- is required for differentiating this differential operator .
2615
+ respect to its `x` argument.
2534
2616
2535
2617
A kneejerk solution is to differentiate derivative functions generated by the
2536
2618
differentiation transform at compile- time, but this leads to problems. For
0 commit comments