[AutoDiff] devirtualize diff witnesses #28480
@swift-ci please test tensorflow
LGTM!
I like your point that "devirtualization" isn't an apt name because there's no virtual dispatch.
Running the pass only with `-O` sounds good.
bool DifferentiabilityWitnessInliner::
inlineDifferentiabilityWitnessesInFunction(SILFunction &F) {
  bool Changed = false;
Minor: how about consistently using camelcase spelling for variables in differentiable programming code? I don't feel too strongly as long as casing is locally consistent.
done
auto *W = I->getWitness();
if (W->isDeclaration() && !F.getModule().loadDifferentiabilityWitness(W))
  continue;
assert(W->isDefinition());
Did you mean to set `Changed` to true here?
done
    SILMod, *linkage, original, parameterIndices, resultIndices,
    derivativeGenSig, jvp, vjp, isSerialized);
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
if (diffWitness->isDeclaration() && !isDeclaration)
Could you please explain when this condition is true and `convertToDefinition` is called?
It isn't wholly obvious; perhaps an explanatory comment would be good.
done
Could you please comment on how this patch impacts |
There is runtime dynamism. The way it works is like protocol witness methods: there's no virtual dispatch, but there is dispatch -- the derivative is fetched at runtime. The SIL devirtualizer applies to witness methods even though there's no virtual dispatch, so I think "devirtualizer" totally applies to differentiability witness functions.
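As a rough analogy for the dispatch described above, here is a toy C++ model. It is not Swift compiler code: `WitnessTable`, `lookupDerivative`, `squareDerivative`, and the other names are hypothetical. The unoptimized form fetches the derivative through a table at runtime, while the devirtualized form names the derivative directly.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical derivative function: d/dx of x * x.
double squareDerivative(double x) { return 2.0 * x; }

using Derivative = double (*)(double);

// Hypothetical witness table: maps an original function's name to its derivative.
struct WitnessTable {
  std::map<std::string, Derivative> entries;

  // Returns the registered derivative, or nullptr if there is none.
  Derivative lookupDerivative(const std::string &original) const {
    auto it = entries.find(original);
    return it == entries.end() ? nullptr : it->second;
  }
};

// Before devirtualization: the derivative is fetched at runtime.
double callViaWitness(const WitnessTable &table, double x) {
  Derivative d = table.lookupDerivative("square");
  assert(d && "no derivative registered for 'square'");
  return d(x);
}

// After devirtualization: the derivative is referenced directly,
// so the lookup disappears and the call becomes visible to the inliner.
double callDirect(double x) { return squareDerivative(x); }
```

Both forms compute the same value; only the way the callee is found differs.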
Given the similarity to the SIL devirtualizer, maybe it's easier to just add some logic in `swift::tryDevirtualizeApply` that handles `DifferentiabilityWitnessFunction` instructions.
This sounds good, I will rename it to devirtualizer
Currently this would be pretty involved because the SIL devirtualizer starts at apply methods and looks for the callee. A lot of the callees in differentiation cases are hidden behind differentiable_function and differentiable_function_extracts. The pass as written handles the devirtualization at the reference sites, avoiding this difficulty.
In combination with #28451, |
Makes sense. We need to do a proper |
@swift-ci please test tensorflow
@swift-ci please test tensorflow macos
Adds an optimization pass that devirtualizes differentiability witnesses into functions that reference them, replacing `differentiability_witness_function`s with `function_ref`s.

Resolves TF-919 and TF-994.
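The shape of the rewrite can be sketched with a toy C++ model. This is not the pass itself: `Witness`, `Instruction`, `loadWitness`, and `devirtualizeWitnesses` are hypothetical stand-ins for the SIL types and helpers. The sketch assumes, as in the snippets quoted in the review, that a witness that is only a declaration and cannot be loaded is left untouched, and that the function reports whether it changed anything.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-ins; these are NOT the Swift compiler's types.
struct Witness {
  bool isDefinition;    // false means only a declaration is available
  bool loadable;        // whether a definition can be loaded on demand
  std::string jvpName;  // name of the derivative function
};

struct Instruction {
  enum Kind { WitnessFunction, FunctionRef } kind;
  Witness *witness;     // used when kind == WitnessFunction
  std::string callee;   // used when kind == FunctionRef
};

// Hypothetical loader: promotes a declaration to a definition if possible.
bool loadWitness(Witness &w) {
  if (!w.loadable)
    return false;
  w.isDefinition = true;
  return true;
}

// Rewrites witness references into direct function references.
// Returns true if at least one instruction was rewritten.
bool devirtualizeWitnesses(std::vector<Instruction> &body) {
  bool changed = false;
  for (Instruction &inst : body) {
    if (inst.kind != Instruction::WitnessFunction)
      continue;
    Witness &w = *inst.witness;
    // Skip witnesses whose definition is not available and cannot be loaded.
    if (!w.isDefinition && !loadWitness(w))
      continue;
    inst = Instruction{Instruction::FunctionRef, nullptr, w.jvpName};
    changed = true;
  }
  return changed;
}
```

Running the function a second time over the same body returns false, because the remaining witness references are exactly the ones that could not be resolved.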
Performance impact
This completely eliminates the performance impact of #28451 under `-O`, except for cross-module non-serialized differentiability witnesses, by causing it to generate the same code that would be generated by tensorflow HEAD. I confirmed this by measuring the microbenchmark posted at #28451 (comment). I haven't yet confirmed that the Google internal model is also fixed, but I will verify that experimentally before I merge #2845.