-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Support @differentiable
class initializers.
#29754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
4aaaa5b
to
43041cc
Compare
Support `@differentiable` class initializers. - Add special logic in `SILGenModule::getOrCreateCustomDerivativeThunk` to thunk user-defined class initializer derivative functions to the expected derivative type for class initializers. - Make minor vtable SILGen changes for class initializer derivatives. Also support differentiation for class-related casting instructions: `unchecked_ref_cast` and `upcast`. These instructions are generated when calling `super.init` or referencing inherited superclass members. Resolves SR-12151 and SR-12153.
43041cc
to
e6c9886
Compare
var base: Tracked<Float> | ||
// Dummy to make `Super.AllDifferentiableVariables` be nontrivial. | ||
var _nontrivial: [Tracked<Float>] = [] | ||
|
||
// TODO(TF-654): Remove attribute when differentiation supports class initializers. | ||
// @differentiable(vjp: vjpInit) | ||
@differentiable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I added upcast
and unchecked_ref_cast
differentiation support in this patch because they're needed to make this superclass initializer as @differentiable
.
Otherwise, inherited initializers encounter non-differentiability errors at weird source locatiosn for unchecked_ref_cast
and upcast
instructions.
@@ -792,6 +792,20 @@ Type TypeBase::replaceCovariantResultType(Type newResultType, | |||
return OptionalType::get( | |||
objectType->replaceCovariantResultType(newResultType, uncurryLevel)); | |||
} | |||
// SWIFT_ENABLE_TENSORFLOW |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: this added logic in TypeBase::replaceCovariantResultType
is hacky. It's necessary for TypeConverter::getConstantOverrideInfo
(derivative vtable entry SILGen) to work properly.
I can't think of a more robust alternative that doesn't involve ~60 lines of code dupe.
One alternative is to drop support for @differentiable
initializers in non-final classes.
Initializers return a dynamic Self
type, so perhaps non-final initializer derivatives shouldn't be supported (since derivatives would return a tuple (Self, ...)
and type-checking accepts covariant Self
only as a top-level method result type).
If there are no use cases for @differentiable
initializers in non-final classes, I think dropping support is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe we should ban @differentiable
on everything that produces a covariant result.
Verifying whether tests pass. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add a test case where a convenience initializer takes an address-only parameter? For example:
final class Foo<T> {
...
@differentiable
convenience init(_ x: T) {
self.init(...)
}
}
Fix `ref_element_addr` adjoint buffer type remapping. Garden tests.
// Consider adding a `bool isAutoDiffDerivative` flag for triggering this | ||
// logic, or creating a separate dedicated helper | ||
// `TypeBase::replaceCovariantResultTypeForAutoDiffDerivative`. | ||
if (auto tupleType = dyn_cast<TupleType>(this)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is in TypeBase
. getAs()
would be more idiomatic.
if (auto tupleType = dyn_cast<TupleType>(this)) { | |
if (auto tupleType = getAs<TupleType>()) { |
Test added in d7b12c2.
That sounds good. Not supporting |
Support
@differentiable
class initializers.SILGenModule::getOrCreateCustomDerivativeThunk
tothunk user-defined class initializer derivative functions to the expected
derivative type for class initializers.
Also support differentiation for class-related casting instructions:
unchecked_ref_cast
andupcast
.These instructions are generated when calling
super.init
or referencinginherited superclass members.
Resolves SR-12151 and SR-12153.