Skip to content

Commit 0cd6e2a

Browse files
author
Marc Rasi
committed
[AutoDiff] remove the last vjp:s from the stdlib
1 parent c23afc8 commit 0cd6e2a

File tree

3 files changed

+46
-20
lines changed

3 files changed

+46
-20
lines changed

stdlib/public/core/Array.swift

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,6 @@ extension Array: RandomAccessCollection, MutableCollection {
693693
/// bridged `NSArray` instance as its storage, in which case writing is
694694
/// O(*n*), where *n* is the length of the array.
695695
@inlinable
696-
// SWIFT_ENABLE_TENSORFLOW
697-
@differentiable(wrt: self, vjp: _vjpSubscript where Element : Differentiable)
698696
public subscript(index: Int) -> Element {
699697
get {
700698
// This call may be hoisted or eliminated by the optimizer. If
@@ -873,8 +871,6 @@ extension Array: RangeReplaceableCollection {
873871
/// `repeating` parameter. `count` must be zero or greater.
874872
@inlinable
875873
@_semantics("array.init")
876-
@differentiable(wrt: repeatedValue, vjp: _vjpInit(repeating:count:)
877-
where Element: Differentiable)
878874
public init(repeating repeatedValue: Element, count: Int) {
879875
var p: UnsafeMutablePointer<Element>
880876
(self, p) = Array._allocateUninitialized(count)
@@ -1330,8 +1326,6 @@ extension Array: RangeReplaceableCollection {
13301326
// operator in the same expression.
13311327
extension Array {
13321328
@inlinable
1333-
// SWIFT_ENABLE_TENSORFLOW
1334-
@differentiable(vjp: _vjpPlus where Element : Differentiable)
13351329
public static func + (lhs: Array, rhs: Array) -> Array {
13361330
var lhs = lhs
13371331
lhs.append(contentsOf: rhs)
@@ -1930,24 +1924,24 @@ extension Array {
19301924
extension Array.DifferentiableView : Differentiable where Element : Differentiable {
19311925
/// The viewed array.
19321926
public var base: [Element] {
1933-
@differentiable(wrt: self, vjp: _vjpBase)
19341927
get { return _base }
19351928
_modify { yield &_base }
19361929
}
19371930

19381931
@usableFromInline
1932+
@derivative(of: base)
19391933
func _vjpBase() ->
1940-
([Element], (Array<Element>.TangentVector) -> TangentVector) {
1934+
(value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector) {
19411935
return (base, { $0 })
19421936
}
19431937

19441938
/// Creates a differentiable view of the given array.
1945-
@differentiable(wrt: base, vjp: _vjpInit)
19461939
public init(_ base: [Element]) { self._base = base }
19471940

19481941
@usableFromInline
1942+
@derivative(of: init(_:))
19491943
static func _vjpInit(_ base: [Element]) ->
1950-
(Array.DifferentiableView, (TangentVector) -> TangentVector) {
1944+
(value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector) {
19511945
return (Array.DifferentiableView(base), { $0 })
19521946
}
19531947

@@ -2088,8 +2082,9 @@ extension Array : EuclideanDifferentiable
20882082
}
20892083

20902084
extension Array where Element : Differentiable {
2085+
@derivative(of: subscript)
20912086
public func _vjpSubscript(index: Int) ->
2092-
(Element, (Element.TangentVector) -> TangentVector)
2087+
(value: Element, pullback: (Element.TangentVector) -> TangentVector)
20932088
{
20942089
func pullback(_ gradientIn: Element.TangentVector) -> TangentVector {
20952090
var gradientOut = Array<Element.TangentVector>(
@@ -2101,8 +2096,9 @@ extension Array where Element : Differentiable {
21012096
return (self[index], pullback)
21022097
}
21032098

2099+
@derivative(of: +)
21042100
public static func _vjpPlus(_ lhs: [Element], _ rhs: [Element]) ->
2105-
([Element], (TangentVector) -> (TangentVector, TangentVector)) {
2101+
(value: [Element], pullback: (TangentVector) -> (TangentVector, TangentVector)) {
21062102
func pullback(_ gradientIn: TangentVector) ->
21072103
(TangentVector, TangentVector) {
21082104
precondition(
@@ -2122,6 +2118,7 @@ extension Array where Element : Differentiable {
21222118

21232119
extension Array where Element: Differentiable {
21242120
@usableFromInline
2121+
@derivative(of: init(repeating:count:))
21252122
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
21262123
value: Self, pullback: (TangentVector) -> Element.TangentVector
21272124
) {

stdlib/public/core/FloatingPoint.swift

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,19 +1680,13 @@ extension FloatingPoint {
16801680
}
16811681

16821682
@_transparent
1683-
// SWIFT_ENABLE_TENSORFLOW
1684-
@differentiable(wrt: self, vjp: _vjpSquareRoot
1685-
where Self : Differentiable, Self == Self.TangentVector)
16861683
public func squareRoot( ) -> Self {
16871684
var lhs = self
16881685
lhs.formSquareRoot( )
16891686
return lhs
16901687
}
16911688

16921689
@_transparent
1693-
/// SWIFT_ENABLE_TENSORFLOW
1694-
@differentiable(wrt: (self, lhs, rhs), vjp: _vjpAddingProduct
1695-
where Self : Differentiable, Self == Self.TangentVector)
16961690
public func addingProduct(_ lhs: Self, _ rhs: Self) -> Self {
16971691
var addend = self
16981692
addend.addProduct(lhs, rhs)
@@ -1741,16 +1735,18 @@ extension FloatingPoint where Self : Differentiable,
17411735
/// original result and pullback of `addingProduct` with respect to `self`,
17421736
/// `lhs` and `rhs`.
17431737
@inlinable
1738+
@derivative(of: addingProduct)
17441739
func _vjpAddingProduct(
17451740
_ lhs: Self, _ rhs: Self
1746-
) -> (Self, (Self) -> (Self, Self, Self)) {
1741+
) -> (value: Self, pullback: (Self) -> (Self, Self, Self)) {
17471742
return (addingProduct(lhs, rhs), { _ in (1, rhs, lhs) })
17481743
}
17491744

17501745
/// The vector-Jacobian product function of `squareRoot`. Returns the original
17511746
/// result and pullback of `squareRoot` with respect to `self`.
17521747
@inlinable // FIXME(sil-serialize-all)
1753-
func _vjpSquareRoot() -> (Self, (Self) -> Self) {
1748+
@derivative(of: squareRoot)
1749+
func _vjpSquareRoot() -> (value: Self, pullback: (Self) -> Self) {
17541750
let y = squareRoot()
17551751
return (y, { v in v / (2 * y) })
17561752
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-run-simple-swiftgyb
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
6+
var FloatingPointTests = TestSuite("FloatingPoint")
7+
8+
%for Self in ['Float', 'Double', 'Float80']:
9+
% if Self == 'Float80':
10+
#if !os(Windows) && (arch(i386) || arch(x86_64))
11+
% end
12+
13+
FloatingPointTests.test("${Self}.squareRoot") {
14+
expectEqual(${Self}(0.5), gradient(at: ${Self}(1), in: { $0.squareRoot() }))
15+
expectEqual(${Self}(0.25), gradient(at: ${Self}(4), in: { $0.squareRoot() }))
16+
}
17+
18+
FloatingPointTests.test("${Self}.addingProduct") {
19+
expectEqual(
20+
(${Self}(1), ${Self}(2), ${Self}(3)),
21+
gradient(
22+
at: ${Self}(10), ${Self}(3), ${Self}(2),
23+
in: { $0.addingProduct($1, $2) }
24+
)
25+
)
26+
}
27+
28+
% if Self == 'Float80':
29+
#endif
30+
% end
31+
% end
32+
33+
runAllTests()

0 commit comments

Comments
 (0)