[AutoDiff] Revamp 'struct_extract' differentiation strategy. #25151
Nice! Could you please make the same change in `VJPEmitter::visitStructElementAddrInst`?
@dan-zheng done! 😄
@swift-ci please test tensorflow Linux
This solution makes sense for all normal cases. However, there's one case which I don't think it handles correctly: if the stored property has a custom VJP (shown below), the custom VJP should be used for differentiating that property instead of going along the fieldwise differentiation route.
struct A : Differentiable & AdditiveArithmetic {
  @differentiable(vjp: vjpPropertyA)
  var a: Float = 1
  typealias TangentVector = A
  typealias AllDifferentiableVariables = A
  func vjpPropertyA() -> (Float, (Float) -> A) {
    (.zero, { _ in .zero })
  }
}
@differentiable
func f(_ x: A) -> Float {
  return x.a * 2
}
// A.a.getter
sil hidden [transparent] [differentiable source 0 wrt 0 jvp @AD__$s4test1AV1aSfvg__jvp_src_0_wrt_0 vjp @$s4test1AV12vjpPropertyASf_ACSfctyF] @$s4test1AV1aSfvg : $@convention(method) (A) -> Float {
// %0 // users: %2, %1
bb0(%0 : $A):
  debug_value %0 : $A, let, name "self", argno 1 // id: %1
  %2 = struct_extract %0 : $A, #A.a // user: %3
  return %2 : $Float // id: %3
} // end sil function '$s4test1AV1aSfvg'
I experimented with this code and found out that when the getter has a custom … However, if this PR were merged, the property in the example above would be differentiated incorrectly, because the shortcut introduced in this PR makes the compiler not use the user-defined VJP. So in addition to the existing change, I would suggest adding a check to see whether the property getter decl has a `@differentiable` attribute:
auto *getterDecl = seai->getField()->getGetter(); // NOTE: Bring these two lines above the `if`.
assert(getterDecl);
if (!getterDecl->getAttrs().hasAttribute<DifferentiableAttr>() &&
    (structDecl->getEffectiveAccess() <= AccessLevel::Internal ||
     structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>())) {
  ...
}
@rxwei great catch! I'll make this change and add a test case for it as well.
d3951e0 to 7897fac
LGTM. I think there are existing tests that define a VJP for a stored property. Have you run tests locally to see whether they should be updated/removed?
Since we changed the solution to the problem, could you update the PR description? We should also make clear in the PR description that custom VJPs for stored properties are banned.
No, not yet! I was going to rerun the tests, but Chrome Remote Desktop was too much of a pain, so I had to push a commit here first so I could access it from home.
b0c063b to ab8cd94
@@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
       "layout requirement are not supported by '@differentiable' attribute", ())
 ERROR(differentiable_attr_class_unsupported,none,
       "class members cannot be marked with '@differentiable'", ())
+ERROR(differentiable_attr_stored_property_unsupported,none,
+      "Stored properties cannot be marked with '@differentiable'", ())
Diagnostic messages should start with a lowercase letter (like other ones).
Seems as though there are stored property tests in …
@swift-ci please test tensorflow
…hr808/swift into TF-523-property-not-differentiable
@@ -2733,6 +2733,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
       "layout requirement are not supported by '@differentiable' attribute", ())
 ERROR(differentiable_attr_class_unsupported,none,
       "class members cannot be marked with '@differentiable'", ())
+ERROR(differentiable_attr_stored_property_variable_unsupported,none,
+      "stored properties/variables cannot be marked with '@differentiable' with a custom VJP/JVP", ())
Let me know if anyone has a better idea for the message, since we also need to account for global variables/constants not being allowed to have custom VJPs/JVPs.
- "VJP" or "JVP" does not formally exist in AD's user vocabulary, and we've been steering the programming model away from these words. In this case, mentioning these terms is unavoidable because the user specified the `jvp:` or `vjp:` label, so it's better to just refer to them verbatim, i.e. `'jvp:' or 'vjp:' cannot be specified for stored properties`. No need to mention `'@differentiable'` in this case since the source location is at the right spot.
- Thanks for catching the `/variables` case!
> Thanks for catching the /variables case!
Update on this: it actually seems that a different error is being emitted in `TypeCheckAttr.cpp`:
// Global immutable vars, for example, have no getter, and therefore trigger
// this.
if (!original) {
diagnoseAndRemoveAttr(attr, diag::invalid_decl_attribute, attr);
return;
}
Currently the error `'@differentiable' attribute cannot be applied to this declaration` is being emitted for both global `var` variables and global `let` constants. To me, it seems as though we should be able to mark global variables as differentiable (not constants though).
lib/Sema/TypeCheckAttr.cpp
Outdated
if (asd->getImplInfo().isSimpleStored() &&
(attr->getJVP() || attr->getVJP())) {
diagnoseAndRemoveAttr(attr,
diag::differentiable_attr_stored_property_variable_unsupported);
return;
Fix indentation.
@swift-ci please test tensorflow
// Differentiate the `struct_extract` by looking up the corresponding getter
// and using its VJP.
Getter
Fieldwise
This whole enum can be nuked now.
What about when a struct extract instruction is not active?
auto &strategies = context.getStructExtractDifferentiationStrategies();
// Special handling logic only applies when the `struct_extract` is active.
// If not, just do standard cloning.
if (!activityInfo.isActive(sei, getIndices())) {
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
strategies.insert(
{sei, StructExtractDifferentiationStrategy::Inactive});
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}
Anything directly involving strategies can be nuked I believe.
Hmm, it seems the code elimination may have broken two tests. I'll investigate, since I think it was working up until I eliminated the code.
Actually, I was wrong: looking at historical runs of the full test suite, it seems to have been failing already 😞
let y: Float

@differentiable
var y: Float

func vjpY() -> (Float, (Float) -> TangentSpace) {
These VJP functions are not used and should be removed.
@@ -70,7 +69,7 @@ E2EDifferentiablePropertyTests.test("stored property") {

 struct GenericMemberWrapper<T : Differentiable> : Differentiable {
   // Stored property.
-  @differentiable(vjp: vjpX)
+  @differentiable
   var x: T

   func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {
Same here.
@@ -383,23 +379,15 @@ func vjpNonDiffResult2(x: Float) -> (Float, Int) {
struct VJPStruct {
  let p: Float

  @differentiable(vjp: storedPropVJP)
  let storedImmutableOk: Float
For type checking tests like this, I'd actually suggest turning them into computed properties. These tests are pretty important.
Seems like it's failing on this reproducer:
struct TangentSpace : VectorNumeric {
let dy: Float // change to `y` and it will work
}
extension TangentSpace : Differentiable {
typealias TangentVector = TangentSpace
}
struct Space {
@differentiable
var y: Float
}
extension Space : Differentiable {
typealias TangentVector = TangentSpace
func moved(along: TangentSpace) -> Space {
return Space(y: y + along.dy)
}
}
let actualGrad2 = gradient(at: Space(y: 0)) { (point: Space) -> Float in
return 3 * point.y
}
let expectedGrad2 = TangentSpace(dy: 3)
What seems to be happening is that in the fieldwise differentiation pass, it splits into two cases:
It's going into the else case and failing to find the tangent of the field for when the … This seems like a bug in the fieldwise differentiation pass for struct extract instructions.
This is an interesting test case, and I realized I missed one important edge case, which is when the …
- Remove `@_fieldwiseDifferentiable`.
- Remove `StructExtractDifferentiationStrategy`.
- Require `TangentVector` to have a member of the same name.
In d0f77e7, I fixed the issue and further simplified the property differentiation semantics by requiring a name correspondence between fields in `T.TangentVector` and fields in `T`.
Case where a user-defined `TangentVector` makes it possible to differentiate `foo.x`:
struct Foo: Differentiable {
  var x: Float
  struct TangentVector: Differentiable, AdditiveArithmetic {
    var x: Float
  }
  ...
}
@differentiable
func funcToDiff(_ foo: Foo) -> Float {
  foo.x
}
Case where a user-defined `TangentVector` does not allow differentiating `foo.x` because it does not have a field named `x`:
struct Foo: Differentiable {
  var x: Float
  // No field named "x"!
  struct TangentVector: Differentiable, AdditiveArithmetic {
    var y: Float
  }
  ...
}
@differentiable
func funcToDiff(_ foo: Foo) -> Float {
  foo.x
}
// error: function is not differentiable
// note: property cannot be differentiated because 'Foo.TangentVector' does not have a member named 'x'
This fully resolves TF-523 and TF-262. Feel free to reword this comment and make it part of the PR description.
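The name-correspondence rule described above can be pictured as a lookup by field name, followed by a diagnostic when the lookup fails. The following is a minimal, self-contained C++ sketch of that idea only. `StructModel`, `findTangentField`, and `missingMemberNote` are hypothetical names made up for illustration; they are not the compiler's actual data structures or APIs.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified model of a struct declaration: a type name plus
// the names of its stored properties. Not a real compiler data structure.
struct StructModel {
  std::string name;
  std::vector<std::string> storedProperties;
};

// Returns the index of the stored property in `tangent` whose name equals
// `fieldName`, or -1 when no such property exists.
int findTangentField(const StructModel &tangent, const std::string &fieldName) {
  for (std::size_t i = 0; i < tangent.storedProperties.size(); ++i) {
    if (tangent.storedProperties[i] == fieldName)
      return static_cast<int>(i);
  }
  return -1;
}

// Builds the note text quoted in the comment above for a failed lookup.
std::string missingMemberNote(const StructModel &tangent,
                              const std::string &fieldName) {
  return "property cannot be differentiated because '" + tangent.name +
         "' does not have a member named '" + fieldName + "'";
}
```

With a tangent model whose only stored property is `y`, looking up `x` fails and the note text matches the one shown in the second case above.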
@swift-ci please test tensorflow
I'm a bit confused regarding how this fixes TF-262, and it also seems that all of TF-21's subtasks are done and closed, but
case AdjointValueKind::Aggregate: {
// FIXME(TF-21): If `TangentVector` is not marked
// `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer.
// for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements()))
// addAdjointValue(std::get<0>(pair), std::get<1>(pair));
llvm_unreachable("Unhandled. Are you trying to differentiate a "
"memberwise initializer?");
}
Could we eliminate this as well?
TF-262 is about …
No. This is not directly related to TF-21. Feel free to remove the comment. I'll start a PR to handle this soon.
Resolves TF-523.
The Bug
SILGen sometimes does not create a getter for a stored property. User-defined derivatives (VJPs) for such properties are not attached to the original function, so the differentiation transform is unable to find or use any custom derivative through the SIL `[differentiable]` attribute on the original getter function. Lots of property differentiation issues arose because of this.
History/Reason Why Bug Exists
Previously, the semantics of automatic differentiation treated stored properties exactly as if they were computed properties. In particular, we allowed stored properties to have user-defined derivatives. This was also required because `TangentVector` could be entirely user-defined, in which case the compiler would not be able to figure out which stored property in a `T.TangentVector` corresponds to a stored property in `T`. As such, we made it so that all differentiation of stored property accesses semantically goes through a derivative function defined for the property (#21567), with an efficient shortcut which does not go through derivative functions but uses a `struct` instruction to create a `TangentVector` when the structure is marked with `@_fieldwiseDifferentiable` (#21575, #21737, #21863, and #22009).
For functions, initializers, and computed properties, a reference to the user-defined derivative is attached to the original function in SIL. However, for stored properties, a getter function may not exist. Also, it is quite cumbersome for a user to define a custom VJP for a stored property, and it's unclear whether anyone would want to do it.
Solution
The solution can be broken into two parts:
- Ban user-defined derivatives for stored properties.
- Require that `TangentVector` contains a stored property of the same name.
Ban user-defined derivatives for stored properties
This step makes it possible to use the fieldwise `struct_extract` differentiation strategy (i.e. forming a `TangentVector` using a `struct` instruction) for all `TangentVector`s that are mathematically the tangent space of the original product space.
Previously, in order to support user-defined derivatives for properties in the same module, we had another `struct_extract` differentiation strategy that worked by replacing the `struct_extract` call with a call to the derivative of the synthesized property getter. However, we had to add a lot of hacky if statements to make this work, and even then it didn't work perfectly. A bunch of bugs and unknown SIL modifications came up, like:
- … `@differentiable` and when the struct is not `@_fieldwiseDifferentiable`.
- … `fileprivate`, or when `swift` is run instead of `swiftc`.
Require `T.TangentVector` fields to correspond to fields in `T`
In this commit, we now require a name correspondence between fields in `T.TangentVector` and fields in `T`, and deprecate the `@_fieldwiseProductSpace` attribute. From now on, in order to differentiate a property access `foo.x`, `type(of: foo).TangentVector` must contain a corresponding property named `x`. If not, a diagnostic will be emitted. This is easier to reason about than the existing behavior, and makes it significantly easier for the user to define custom tangent spaces.
Case where a user-defined `TangentVector` makes it possible to differentiate `foo.x`:
Case where a user-defined `TangentVector` does not allow differentiating `foo.x` because it does not have a field named `x`.
Implementation
- Diagnose `vjp:` or `jvp:` in a `@differentiable` attribute on a stored property.
- Make `AdjointEmitter::visitStructExtractInst` and `AdjointEmitter::visitStructElementAddrInst` differentiate fieldwise, forming a `TangentVector` using a `struct` instruction.
- Remove `@_fieldwiseDifferentiable` everywhere.
- Remove `StructExtractDifferentiationStrategy` and all related logic.
- When `TangentVector` does not contain a property with the requested name that corresponds to a property in the original type, emit a diagnostic `property cannot be differentiated because 'Foo.TangentVector' does not have a member named 'x'`.
- Turn `ADContext::emitNondifferentiabilityError` methods into templates that take diagnostics with arbitrary arguments. This makes it possible for us to emit finer-grained diagnostics in the future.
This change resulted in a 300-LOC code removal, which we should celebrate.
Tests
Had to fix a lot of tests, since a lot of them were relying on defining a custom VJP on stored properties.
Other Tickets/PRs
This fix would unblock me on the Complex Number PR and also fix TF-262.