-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AD] Rewrite some of the tests with Tracked<Float> (2) #27767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@swift-ci please test tensorflow |
return pullback(at: x, in: f)(1) | ||
// Differential operators for `Tracked<T>`. | ||
|
||
public func gradient<T, U>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Differential operators in this file can actually be deleted once Tracked
conforms to FloatingPoint
. Were you working in this direction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC,Tracked: FloatingPoint
conformance caused *
operator lookup for @differentiating(*)
to become ambiguous.
That seems workaround-able by using @differentiable(vjp: ...)
for now. We should probably investigate fixing @differentiating(*)
ambiguous lookup (and @differentiating
original declaration lookup for initializers/subscripts/properties) sometime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rxwei, yes it would be good to make Tracked
conform to FloatingPoint
, but have issues that Dan mentions. I had already filed a bug: https://bugs.swift.org/browse/TF-926
// MethodTests.testWithLeakChecking( | ||
// "instance method with generated adjoint, wrt self and non-self" | ||
// ) { | ||
// let g = #gradient({ (p: Parameter, o: Tracked<Float>) in p.multiplied(with: o) }) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a comment: some tests are outdated:
- Reference to outdated
#gradient
expression. - Reference to "adjoint" functions, which are now called "pullback "functions.
- Deprecated comment; multiple
@differentiable
attributes are now possible:
// There is currently no way to define multiple custom VJPs wrt different
// parameters on the same func, so we define a copy of this func per adjoint.
Filed TF-929 to track updating outdated AutoDiff
tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I generally did not make changes in the comments, but this should have slipped through.
return pullback(at: x, in: f)(1) | ||
// Differential operators for `Tracked<T>`. | ||
|
||
public func gradient<T, U>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Single gradient operator definitions for (...) -> Tracked<T>
functions is great!
For reference: TF-927 tracks the confusing type-checker error we encountered this morning.
Co-Authored-By: Richard Wei <[email protected]>
Enable more leak checking for the AD tests: https://bugs.swift.org/browse/TF-895