-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff upstream] Add @differentiable
declaration attribute type-checking.
#29231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Type-checking rules (summary): - `@differentiable` attribute must be declared on a function-like "original" declaration: `func`, `init`, `subscript`, `var` (computed properties only). - Parsed differentiability parameters must be valid (if they exist). - Parsed `where` clause must be valid (if it exists). - Differentiability parameters must all conform to `Differentiable`. - Original result must all conform to `Differentiable`. - If JVP/VJP functions are specified, they must match the expected type. - `@differentiable(jvp:vjp:)` for derivative registration is deprecated in favor of `@derivative` attribute, and will be removed soon. - Duplicate `@differentiable` attributes with the same differentiability parameters are invalid. - For protocol requirements and class members with `@differentiable` attribute, conforming types and subclasses must have the same `@differentiable` attribute (or one with a superset of differentiability parameter indices) on implementing/overriding declarations. Code changes: - Add `DifferentiableAttributeTypeCheckRequest`. - Currently, the request returns differentiability parameter indices, while also resolving `JVPFunction`, `VJPFunction`, and `DerivativeGenericSignature` and mutating them in-place in `DifferentiableAttr`. This works fine for now. - Add "is type-checked" bit to `DifferentiableAttr`. - Alternatively, I tried changing `CacheKind::SeparatelyCached` to `CacheKind::Cached`, but it did not seem to work: `@differentiable` attributes in non-primary-files were not type-checked for some reason. Upstream disorganized tests as-is from `tensorflow` branch. Resolves TF-828.
Let's see if tests pass. |
@@ -530,7 +530,114 @@ swift::matchWitness( | |||
} | |||
|
|||
// Now finalize the match. | |||
return finalize(anyRenaming, optionalAdjustments); | |||
auto result = finalize(anyRenaming, optionalAdjustments); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I anticipate that some code owners might care strongly about preserving return finalize(anyRenaming, optionalAdjustments)
here. It doesn't matter too much to me, so I upstreamed code from tensorflow
branch as-is.
Aside: this logic for diagnosing unmet @differentiable
attribute requirement doesn't work when associated type inference is involved in witness checking. TF-1014 tracks that issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I anticipate that some code owners might care strongly about preserving
return finalize(anyRenaming, optionalAdjustments)
here.
Slightly mitigating the above point: I refactored swift::matchWitness
logic for checking @differentiable
attributes into a matchWitnessDifferentiableAttr
helper. So the code changes to swift::matchWitness
are minimal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Please clarify on the PR title that this is "Adding @differentiable
declaration attribute type-checking" to avoid any confusion with the @differentiable
function type attribute.
@differentiable
attribute type-checking.@differentiable
declaration attribute type-checking.
/// - `DerivativeGenericSignature` | ||
class DifferentiableAttributeTypeCheckRequest | ||
: public SimpleRequest<DifferentiableAttributeTypeCheckRequest, | ||
IndexSubset *(DifferentiableAttr *), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, DifferentiableAttributeTypeCheckRequest
takes a DifferentiableAttr *
and returns an IndexSubset *
representing parameter indices. It also resolves two FuncDecl *
(JVPFunction
and VJPFunction
) and a GenericSignature
(DerivativeGenericSignature
), mutating them in-place in the DifferentiableAttr *
.
This works fine for now. I thought of two ways to refactor the request to avoid mutation:
- Make
DifferentiableAttributeTypeCheckRequest
return a tuple/struct of the resolved components.- This caused many request cycles, since
DifferentiableAttr::get{JVPFunction,VJPFunction,DerivativeGenericSignature}
now all trigger the request. I haven't debugged further.
- This caused many request cycles, since
- Make an individual request for resolving each
@differentiable
attribute component: originalAbstractFuncDecl
(s), parameter indices, JVP/VJPFuncDecl
(if specified), derivativeGenericSignature
(if specified).- I haven't tried this. It seems like a fair amount of work, and the upside is unclear.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
History: #28017 moved @differentiable
attribute type-checking logic from AttributeChecker::visitDifferentiableAttr
to DifferentiableAttributeTypeCheckRequest::evaluate
.
The current request approach (return IndexSubset *
but mutate other components in-place) is the first one that worked.
class DifferentiableAttributeTypeCheckRequest | ||
: public SimpleRequest<DifferentiableAttributeTypeCheckRequest, | ||
IndexSubset *(DifferentiableAttr *), | ||
CacheKind::SeparatelyCached> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using CacheKind::Cached
seems preferable so that DifferentiableAttr
is truly stateless, adhering to the request evaluator caching vision.
I tried changing CacheKind::SeparatelyCached
to CacheKind::Cached
(some time ago on tensorflow
branch), but it did not seem to work: @differentiable
attributes in non-primary-files were left unchecked.
I'm not sure why CacheKind::Cached
did not work. I did notice that quite a few other requests (e.g. InterfaceTypeRequest
) used CacheKind::SeparatelyCached
for some reason - perhaps that's just for request cycle breaking?
@CodaFi: would you like to review this PR, or suggest a reviewer? I'm asking you since you seem like a request evaluator owner. This patch adds a I created two discussion threads regarding possible improvements:
Any improvement suggestions would be appreciated! If you'd like me to split this patch into two smaller ones (idea), I'm happy to do so. Otherwise, I'll merge this patch within 2 days so that it's in a development snapshot release before our |
- Refactor `DifferentiableAttributeTypeCheckRequest::evaluate` into smaller helpers. - Consider rewriting helpers as their own requests. - Add explanatory comment for `DifferentiableAttr::ParameterIndicesAndBit`. - Refactor `swift::matchWitness` logic for checking `@differentiable` attributes into a `matchWitnessDifferentiableAttr` helper.
@swift-ci Please test |
Merging now, happy to address further feedback later! |
The
@differentiable
attribute marks a function as differentiable.Example:
The
@differentiable
attribute has an optionalwrt:
clause specifying theparameters that are differentiated "with respect to", i.e. the differentiability
parameters. The differentiability parameters must conform to the
Differentiable
protocol.If the
wrt:
clause is unspecified, the differentiability parameters arecurrently inferred to be all parameters that conform to
Differentiable
.The
@differentiable
attribute also has optionaljvp:
andvjp:
labelsfor registering derivative functions. These labels are deprecated in favor of
the
@derivative
attribute and will be removed soon.The
@differentiable
attribute also has an optionalwhere
clause, specifyingextra differentiability requirements for generic functions.
The
@differentiable
attribute is gated by the-enable-experimental-differentiable-programming
flag.Type-checking rules (summary):
@differentiable
attribute must be declared on a function-like "original"declaration:
func
,init
,subscript
,var
(computed properties only).where
clause must be valid (if it exists).Differentiable
.Differentiable
.@differentiable(jvp:vjp:)
for derivative registration is deprecated infavor of
@derivative
attribute, and will be removed soon.@differentiable
attributes with the same differentiabilityparameters are invalid.
@differentiable
attribute,conforming types and subclasses must have the same
@differentiable
attribute(or one with a superset of differentiability parameter indices) on
implementing/overriding declarations.
These rules are consistent with the differentiable programming manifesto.
The main proposed rules are implemented.
Code changes:
DifferentiableAttributeTypeCheckRequest
.while also resolving
JVPFunction
,VJPFunction
, andDerivativeGenericSignature
and mutating them in-place inDifferentiableAttr
. This works fine for now.DifferentiableAttr
.CacheKind::SeparatelyCached
toCacheKind::Cached
, but it did not seem to work:@differentiable
attributes in non-primary-files were left unchecked.
Upstream disorganized tests as-is from
tensorflow
branch.Resolves TF-828: upstream
@differentiable
attribute type-checking.