-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[API] [AD] Revamp @differentiable
usages in stdlib.
#21732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[API] [AD] Revamp @differentiable
usages in stdlib.
#21732
Conversation
WIP, triggering tests to see what breaks. |
I think we should change extension Tensor : Differentiable where Scalar : FloatingPoint This kinda works, but what if, in a context where extension Tensor : Differentiable where Scalar : Differentiable & FloatingPoint |
- Use `FloatingPoint` rather than `BinaryFloatingPoint` to constrain differentiability. - Follows from: - swiftlang#21673 - tensorflow/swift-bindings#11 - Use `@differentiable` where clauses to constrain differentiability of numeric operations. - The most common constraint is `where Scalar : FloatingPoint` because `Tensor` conditionally conforms to `Differentiable where Scalar : FloatingPoint`. Todos: - Make more `Tensor` operations differentiable. - This includes reduction and broadcasting ops. - This is enabled by `@differentiable` where clause type-checking. - Use VJP functions instead of adjoint functions. - I would prefer that this be done in a separate patch, after this patch adds the correct `@differentiable` where clauses. - Add tests for newly `@differentiable` `Tensor` operations.
If a custom `@differentiable` attribute defines a VJP and where clause requirements, VJP applications should use a substitution map involving those requirements. Note: more related cases need to be handled, such as `@differentiable` attributes with where clause requirements but no VJP. These cases will be handled later.
aa7b5af
to
cba5a2b
Compare
This makes sense. I'll adopt this change in this PR. |
…tiable`. `Tensor` now conditionally conforms to `Differentiable` where `Scalar : Differentiable & FloatingPoint`. All `@differentiable` where clauses and adjoint definitions have been updated accordingly. Allow `@differentiable` where clause conformance requirements to protocol composition types.
214327a
to
96fffb2
Compare
@swift-ci Please test tensorflow |
Merging to unblock progress. |
Would you mind changing the PR description to mention the substitution map change? |
Done! Also added info about the last commit ( |
Yeah, this will be handled (and required) in differentiation through generics. |
FloatingPoint
rather thanBinaryFloatingPoint
to constraindifferentiability.
FloatingPoint
, notBinaryFloatingPoint
. #21673FloatingPoint
constraint insteaad ofBinaryFloatingPoint
. tensorflow/swift-bindings#11@differentiable
where clauses to constrain differentiabilityof numeric operations.
where Scalar : FloatingPoint
becauseTensor
conditionally conforms toDifferentiable where Scalar : FloatingPoint
.Tensor
now conditionally conforms toDifferentiable
whereScalar : Differentiable & FloatingPoint
.@differentiable
where clause conformance requirements to protocol composition types.@differentiable
attribute defines a VJP and where clauserequirements, VJP applications should use a substitution map involving
those requirements.
@differentiable
attributes with where clause requirements but no VJP. These cases will
be handled later.
Todos:
Tensor
operations differentiable.@differentiable
where clause type-checking.adds the correct
@differentiable
where clauses.@differentiable
Tensor
operations.