-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Support differentiation of switch_enum
.
#25509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoDiff] Support differentiation of switch_enum
.
#25509
Conversation
Handle `switch_enum` terminator during VJP and adjoint generation. Necessary step for differentiating `for-in` loops, which contain optional iterator `next()` values. Diagnose differentiation of active enum values, which requires further adjoint generation support.
@@ -4155,6 +4215,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |||
auto addActiveValue = [&](SILValue v) { | |||
if (visited.count(v)) | |||
return; | |||
// Diagnose active enum values. Differentiation of enum values is not |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Diagnosing active enum values is necessary because adjoint generation doesn't propagate adjoint values of enum associated values correctly.
Support is non-trivial because switch_enum
operand and successor block arguments have different types: the operand has an enum type but successor block arguments have associated values' type. Adjoint value propagation needs to construct enum adjoint value from associated values' adjoint values.
@@ -1533,6 +1533,15 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, | |||
setVaried(cbi->getFalseBB()->getArgument(opIdx), i); | |||
} | |||
} | |||
// Handle `switch_enum`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems necessary to also handle switch_enum_addr
, but I didn't find Swift functions containing switch_enum_addr
that also don't contain active enum values. A quick searches shows that switch_enum_addr
is generated during later SIL passes and also for optional force-unwrapping.
@swift-ci Please test tensorflow |
https://ci-external.swift.org/job/swift-PR-TensorFlow-Linux is down, ran tests locally. $ ./swift/utils/build-script --preset tensorflow_test
...
Testing Time: 58.74s
Expected Passes : 10661
Expected Failures : 27
Unsupported Tests : 1385
-- check-swift-validation-linux-x86_64 finished --
Testing Time: 53.18s
Expected Passes : 1167
Expected Failures : 9
Unsupported Tests : 10897
-- check-swift-validation-optimize-linux-x86_64 finished -- Merging to unblock progress. Todo: add more tests. |
Handle
switch_enum
terminator during VJP and adjoint generation.Necessary step for differentiating
for-in
loops, which containoptional iterator
next()
values.Diagnose differentiation of active enum values, which requires
further adjoint generation support.
Example: