-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] [stdlib] Made arrays differentiable #23183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff] [stdlib] Made arrays differentiable #23183
Conversation
Thanks Richard! I addressed all the feedback and tests seem to pass locally on my machine! :) |
Have you checked in those tests? I don’t see them in this PR. You can create a new file: test/AutodDiff/array.swift for tests. |
Sorry I haven't. I'll add a few more and push in a few minutes. Quick question: is there a way to recompile and run only the newly added tests? |
Yes! To do an incremental compile, cd to You can also run To run a specific test, you need to construct a |
Thanks a lot Marc! That's extremely useful! :) I made a couple simple tests, including the example you mentioned in the discussion group. I'll make sure they pass and commit them. |
LGTM, thanks for pushing this through! If you can check in the tests, that'd be great! |
Thanks for all the help! I'm about to commit a couple simple tests, but I can't seem to run them using the |
Here's an updated implementation that takes @rxwei's comments (above + some in-person discussion) into account: https://github.com/fastai/fastai_docs/blob/master/dev_swift/01c_array_differentiable.ipynb . (Also viewable as gist if the github ipynb viewer is broken: https://gist.github.com/marcrasi/b7618e3e7a8a8a920b73e662a40c4c7b#file-array-differentiable-swift) @eaplatanios, what do you think about this approach? |
@marcrasi Thanks lot for the update! I'll go through it later today, but just a quick question that pops in my mind is: what would the VJP of the |
|
@rxwei assuming that AD supports mutating functions and inout arguments, what would be the array length for the gradient? |
We can think of
The |
@eaplatanios, do you mind if I update this PR with these changes and try to get it merged soon? We're cutting a "v0.3" release of S4TF this evening, and we'd really like to have differentiable arrays in it! |
…/swift-language into eaplatanios-array-differentiable
@marcrasi Sorry it's been a super busy day so far. I took a look at your update and I agree it makes sense and like how we also avoid a lot of code duplication. Plus it makes more sense mathematically (I really like the last example you added in your earlier post). Feel free to update the PR and try to get it merged soon. I'm excited to see this making it in the 0.3 release. :) |
Let me know when it's ready to review! |
@rxwei This is ready to review! |
Note about thing I ran into: I also got the crash that @eaplatanios had. It seems related to cross-module stuff, which is why I didn't get the crash when I was playing with the gists. I fixed it by adding |
stdlib/public/core/Array.swift
Outdated
// TODO: Determine if that is a bug, and fix. | ||
public var array: [Element] { | ||
@differentiable(wrt: self, vjp: _vjpArray) | ||
get { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use _read
and _modify
accessors instead of get
/set
to make it more efficient.
get { | |
_read { yield base } | |
_modify { yield &base } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The @differentiable
attr doses not like _read
. It says:
/usr/local/google/home/marcrasi/swift-base/swift/stdlib/public/core/Array.swift:1888:8: error: cannot differentiate void function '_'
@differentiable(wrt: self, vjp: _vjpBase)
^
So I will leave it as get
/ _modify
for now, and this whole problem will go away once we fix the @differentiable
stored property issue.
// I'm implementing this as a computed property instead of directly exposing `base` because the | ||
// `@differentiable` annotation does not make the stored property actually differentiable. I | ||
// think this is a bug. Maybe it's related to `@_fixed_layout`? | ||
// TODO: Determine if that is a bug, and fix. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. How about renaming the current base
to _base
, and renaming array
/_vjpArray
to base
/_vjpBase
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stdlib/public/core/Array.swift
Outdated
|
||
/// Construct a view of the given array. | ||
@differentiable(wrt: array, vjp: _vjpInit) | ||
public init(_ array: [Element]) { self.base = array } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit:
public init(_ array: [Element]) { self.base = array } | |
public init(_ base: [Element]) { self.base = base } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stdlib/public/core/Array.swift
Outdated
return (array, { $0 }) | ||
} | ||
|
||
/// Construct a view of the given array. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Construct a view of the given array. | |
/// Creates a differentiable view of the given array. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stdlib/public/core/Array.swift
Outdated
|
||
// SWIFT_ENABLE_TENSORFLOW | ||
extension Array where Element : Differentiable { | ||
/// Views the array as the differentiable product manifold of `Element` multiplied with itself |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Views the array as the differentiable product manifold of `Element` multiplied with itself | |
/// The view of an array as the differentiable product manifold of `Element` multiplied with itself |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stdlib/public/core/Array.swift
Outdated
get { | ||
return AllDifferentiableVariables(array.map { $0.allDifferentiableVariables }) | ||
} | ||
set(newValue) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Omit newValue
since the name is implicitly bound.
set(newValue) { | |
set { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stdlib/public/core/Array.swift
Outdated
public func tangentVector(from cotangentVector: CotangentVector) -> TangentVector { | ||
precondition( | ||
array.count == cotangentVector.array.count, | ||
"cannot use Array.DifferentiableView with count \(array.count) to get tangentVector from " + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these over 80 columns?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops, I thought the limit was 100 based on some existing lines that go over 80
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes sorry that was my mistake. I assumed it was 100 because that's in the Google Swift Style Guide.
stdlib/public/core/Array.swift
Outdated
get { | ||
return DifferentiableView(self).allDifferentiableVariables | ||
} | ||
set(v) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set(v) { | |
set { |
Use newValue
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -104,6 +104,13 @@ extension Array : KeyPathIterable { | |||
} | |||
} | |||
|
|||
extension Array.DifferentiableView : KeyPathIterable { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need a special implementation here. Derived conformances should just do it for you when you declare : KeyPathIterable
at DifferentiableView
's declaration site.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
stdlib/public/core/Array.swift
Outdated
} | ||
} | ||
|
||
public func _vjpArray() -> ([Element], (Array<Element>.CotangentVector) -> CotangentVector) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make all VJPs be @usableFromInline internal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
The crash could be caused by |
@swift-ci please test tensorflow |
5 similar comments
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
@swift-ci please test tensorflow |
@eaplatanios Conforming |
Wow this is really cool! On a side note, I also run into the array literal differentiability issue. I was wondering, how can we add a VJP for that? |
It requires tweaking the AD transform to recognize |
This is a first attempt at making arrays differentiable.