Skip to content

Mark Differentiable related array methods with inlinable for big performance boost #75778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions stdlib/public/Differentiation/ArrayDifferentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,29 @@ extension Array where Element: Differentiable {
/// multiplied with itself `count` times.
@frozen
public struct DifferentiableView {
@usableFromInline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beware, adding @usableFromInline introduces a newly exported symbol without availability. It would not be safe to make this change in a module that's expected to have stable ABI. Newly built code after this change may not be binary compatible with Differentiation in earlier Swift releases.

(However, as I understand it, while Differentation is being build with library evolution enabled, it is not distributed as such on any platform where Swift is ABI stable. It is also unclear if the problematic stored property accessor exports are ever actually called in practice for a @frozen structure like this.)

Copy link
Contributor Author

@JaapWijnen JaapWijnen Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review @lorentey !
Just to double check, this is considered unsafe because the struct was already decorated with @frozen? Or is using @usableFromInline without public availability generally ABI unstable? I might not fully grasp the subtlety here but would love to understand better to keep these in mind for future changes.
But indeed Differentiation is not yet distributed to a platform that is ABI stable, so it's not an issue yet!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxwei in your opinion, are we good to go here? As @lorentey points out we break ABI here but that's currently not an issue yet right? I'll add some comments to the main PR message regarding ABI stability.

As an additional question, just for my understanding. Is the following change binary incompatible? And if so why? It seems to me that we're only adding information to the interface not changing the binary layout of the struct. But I also don't have a lot of experience here so would like to understand the details if possible!

@frozen 
struct Thing {
    var storage: Float
}

->

@frozen 
struct Thing {
    @usableFromInline
    var storage: Float
}

Also tagging @asl

var _base: [Element]
}
}

extension Array.DifferentiableView: Differentiable
where Element: Differentiable {
/// The viewed array.
@inlinable
public var base: [Element] {
get { _base }
_modify { yield &_base }
}

@usableFromInline
@inlinable
@derivative(of: base)
func _vjpBase() -> (
value: [Element], pullback: (Array<Element>.TangentVector) -> TangentVector
) {
return (base, { $0 })
}

@usableFromInline
@inlinable
@derivative(of: base)
func _jvpBase() -> (
value: [Element], differential: (Array<Element>.TangentVector) -> TangentVector
Expand All @@ -50,17 +52,18 @@ where Element: Differentiable {
}

/// Creates a differentiable view of the given array.
@inlinable
public init(_ base: [Element]) { self._base = base }

@usableFromInline
@inlinable
@derivative(of: init(_:))
static func _vjpInit(_ base: [Element]) -> (
value: Array.DifferentiableView, pullback: (TangentVector) -> TangentVector
) {
return (Array.DifferentiableView(base), { $0 })
}

@usableFromInline
@inlinable
@derivative(of: init(_:))
static func _jvpInit(_ base: [Element]) -> (
value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector
Expand All @@ -71,6 +74,7 @@ where Element: Differentiable {
public typealias TangentVector =
Array<Element.TangentVector>.DifferentiableView

@inlinable
public mutating func move(by offset: TangentVector) {
if offset.base.isEmpty {
return
Expand All @@ -88,6 +92,7 @@ where Element: Differentiable {

extension Array.DifferentiableView: Equatable
where Element: Differentiable & Equatable {
@inlinable
public static func == (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
Expand All @@ -98,6 +103,7 @@ where Element: Differentiable & Equatable {

extension Array.DifferentiableView: ExpressibleByArrayLiteral
where Element: Differentiable {
@inlinable
public init(arrayLiteral elements: Element...) {
self.init(elements)
}
Expand All @@ -123,10 +129,12 @@ extension Array.DifferentiableView: CustomReflectable {
extension Array.DifferentiableView: AdditiveArithmetic
where Element: AdditiveArithmetic & Differentiable {

@inlinable
public static var zero: Array.DifferentiableView {
return Array.DifferentiableView([])
}


@inlinable
public static func + (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
Expand All @@ -143,6 +151,7 @@ where Element: AdditiveArithmetic & Differentiable {
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
}

@inlinable
public static func - (
lhs: Array.DifferentiableView,
rhs: Array.DifferentiableView
Expand Down Expand Up @@ -180,6 +189,7 @@ extension Array: Differentiable where Element: Differentiable {
public typealias TangentVector =
Array<Element.TangentVector>.DifferentiableView

@inlinable
public mutating func move(by offset: TangentVector) {
var view = DifferentiableView(self)
view.move(by: offset)
Expand All @@ -192,7 +202,7 @@ extension Array: Differentiable where Element: Differentiable {
//===----------------------------------------------------------------------===//

extension Array where Element: Differentiable {
@usableFromInline
@inlinable
@derivative(of: subscript)
func _vjpSubscript(index: Int) -> (
value: Element, pullback: (Element.TangentVector) -> TangentVector
Expand All @@ -207,7 +217,7 @@ extension Array where Element: Differentiable {
return (self[index], pullback)
}

@usableFromInline
@inlinable
@derivative(of: subscript)
func _jvpSubscript(index: Int) -> (
value: Element, differential: (TangentVector) -> Element.TangentVector
Expand All @@ -218,7 +228,7 @@ extension Array where Element: Differentiable {
return (self[index], differential)
}

@usableFromInline
@inlinable
@derivative(of: +)
static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> (
value: Self,
Expand All @@ -241,7 +251,7 @@ extension Array where Element: Differentiable {
return (lhs + rhs, pullback)
}

@usableFromInline
@inlinable
@derivative(of: +)
static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> (
value: Self,
Expand All @@ -261,7 +271,7 @@ extension Array where Element: Differentiable {


extension Array where Element: Differentiable {
@usableFromInline
@inlinable
@derivative(of: append)
mutating func _vjpAppend(_ element: Element) -> (
value: Void, pullback: (inout TangentVector) -> Element.TangentVector
Expand All @@ -274,7 +284,7 @@ extension Array where Element: Differentiable {
})
}

@usableFromInline
@inlinable
@derivative(of: append)
mutating func _jvpAppend(_ element: Element) -> (
value: Void,
Expand All @@ -286,7 +296,7 @@ extension Array where Element: Differentiable {
}

extension Array where Element: Differentiable {
@usableFromInline
@inlinable
@derivative(of: +=)
static func _vjpAppend(_ lhs: inout Self, _ rhs: Self) -> (
value: Void, pullback: (inout TangentVector) -> TangentVector
Expand All @@ -302,7 +312,7 @@ extension Array where Element: Differentiable {
})
}

@usableFromInline
@inlinable
@derivative(of: +=)
static func _jvpAppend(_ lhs: inout Self, _ rhs: Self) -> (
value: Void, differential: (inout TangentVector, TangentVector) -> Void
Expand All @@ -313,7 +323,7 @@ extension Array where Element: Differentiable {
}

extension Array where Element: Differentiable {
@usableFromInline
@inlinable
@derivative(of: init(repeating:count:))
static func _vjpInit(repeating repeatedValue: Element, count: Int) -> (
value: Self, pullback: (TangentVector) -> Element.TangentVector
Expand All @@ -326,7 +336,7 @@ extension Array where Element: Differentiable {
)
}

@usableFromInline
@inlinable
@derivative(of: init(repeating:count:))
static func _jvpInit(repeating repeatedValue: Element, count: Int) -> (
value: Self, differential: (Element.TangentVector) -> TangentVector
Expand Down