[AutoDiff upstream] Add `differentiable_function` canonicalization. #30818
Canonicalizes `differentiable_function` instructions by filling in missing derivative function operands.

Derivative function emission rules, based on the original function value:

- `function_ref`: look up differentiability witness with the exact or a minimal superset derivative configuration. Emit a `differentiability_witness_function` for the derivative function.
- `witness_method`: emit a `witness_method` with the minimal superset derivative configuration for the derivative function.
- `class_method`: emit a `class_method` with the minimal superset derivative configuration for the derivative function.

If an *actual* emitted derivative function has a superset derivative configuration versus the *desired* derivative configuration, create a "subset parameters thunk" to thunk the actual derivative to the desired type.

For `differentiable_function` instructions formed from curry thunk applications: clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new version with type `(Self) -> @differentiable (T, ...) -> U`.

Progress towards TF-1211.
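To make the "exact or minimal superset derivative configuration" lookup concrete, here is a minimal standalone sketch. It is not the compiler's implementation: `IndexSet`, `isSupersetOf`, and `findMinimalSupersetConfig` are hypothetical names, a configuration is modeled only as a set of parameter indices, and "minimal" is assumed to mean "fewest parameter indices".

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <vector>

// Hypothetical stand-in for a derivative configuration: the set of
// parameter indices that a derivative is registered with respect to.
using IndexSet = std::set<unsigned>;

// Returns true if `superset` contains every index in `subset`.
// std::set iterates in sorted order, as std::includes requires.
bool isSupersetOf(const IndexSet &superset, const IndexSet &subset) {
  return std::includes(superset.begin(), superset.end(),
                       subset.begin(), subset.end());
}

// Returns the exact match for `desired` if one is available; otherwise the
// available superset with the fewest indices (assumed meaning of "minimal");
// otherwise std::nullopt.
std::optional<IndexSet>
findMinimalSupersetConfig(const std::vector<IndexSet> &available,
                          const IndexSet &desired) {
  std::optional<IndexSet> best;
  for (const IndexSet &candidate : available) {
    if (!isSupersetOf(candidate, desired))
      continue;
    if (candidate == desired)
      return candidate; // An exact match always wins.
    if (!best || candidate.size() < best->size())
      best = candidate;
  }
  return best;
}
```

When the returned configuration is a strict superset of the desired one, the result is the case that needs a subset parameters thunk.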
    std::tie(thunk, interfaceSubs) =
        getOrCreateSubsetParametersThunkForDerivativeFunction(
            fb, origFnOperand, derivativeFn, derivativeFnKind, desiredIndices,
            actualIndices);
Note: subset parameters thunk tests are blocked by derivative function emitters, which will be upstreamed next.
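As a rough illustration of what a subset parameters thunk does, the sketch below models only the pullback side in plain C++. The names `Pullback` and `makeSubsetParametersThunk` are hypothetical and unrelated to the SIL-level thunk generated by `getOrCreateSubsetParametersThunkForDerivativeFunction`; the assumption is that a pullback returns one gradient per differentiated parameter, in index order.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical model: a pullback maps an output seed to one gradient per
// differentiated parameter, ordered like the parameter indices it was
// built for.
using Pullback = std::function<std::vector<double>(double)>;

// Wraps a pullback defined for `actualIndices` so that it only returns the
// gradients for `desiredIndices`, which is assumed to be a subset of
// `actualIndices`.
Pullback makeSubsetParametersThunk(Pullback actual,
                                   std::vector<unsigned> actualIndices,
                                   std::vector<unsigned> desiredIndices) {
  return [=](double seed) {
    std::vector<double> all = actual(seed);
    std::vector<double> result;
    for (unsigned desired : desiredIndices)
      for (std::size_t i = 0; i < actualIndices.size(); ++i)
        if (actualIndices[i] == desired)
          result.push_back(all[i]); // Keep only the requested gradient.
    return result;
  };
}
```

For example, for `f(x, y) = x * y` at `(3, 4)`, the actual pullback with respect to both parameters returns `{4 * seed, 3 * seed}`; thunking it to parameter index 1 yields a pullback that returns only `{3 * seed}`.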
    /// Check whether the given requirements are satisfied, with the given
    /// derivative generic signature (containing requirements), and substitution
    /// map. Returns true if error is emitted.
    static bool diagnoseUnsatisfiedRequirements(ADContext &context,
Todo: upstream `differentiation_transform_diagnostics.swift`, which tests `diagnoseUnsatisfiedRequirements`.
@gottesmm: would you like to help review this patch, as a SILOptimizer code owner?
@swift-ci Please test
Build failed
@swift-ci Please test Linux
Merging to unblock progress. Happy to address any feedback later!
@dan-zheng some comments. Please fix:
Also, are these used anywhere but the differentiation pass? If not, we should just put them in the differentiation pass until we find multiple uses. A bunch of the helpers that you added need to be refactored. Please fix them. I may look at this again in a bit.
    /// Emit a zero value into the given buffer access by calling
    /// `AdditiveArithmetic.zero`. The given type must conform to
    /// `AdditiveArithmetic`.
    void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
Please fix these APIs to not take a `SILBuilder`. Instead, they should take an insertion point and a `SILBuilderContext`.
This change seems important for correctness. It can probably help clean up this code:
https://github.com/apple/swift/blob/13d5a8addbe3605984edc4ce7c6cbf1f0649a9e1/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp#L486-L497
I started looking into this but ran into some difficulties. The user above wants to insert at the end of a basic block (`SILBasicBlock::iterator`), but no appropriate constructor exists. Only `SILBuilderWithScope(SILInstruction *I, SILBuilderContext &C)` is recommended. `SILBuilderWithScope(SILBasicBlock::iterator I, SILBuilderContext &C)` would work but doesn't exist. Can we add it, or are there other considerations?
    @@ -33,6 +34,12 @@ namespace autodiff {
    /// This is being used to print short debug messages within the AD pass.
    raw_ostream &getADDebugStream();

    /// Given a function call site, gathers all of its actual results (both direct
    /// and indirect) in an order defined by its result type.
    void collectAllActualResultsInTypeOrder(
I have a number of problems with this. I am pretty sure we already have something like this.
    auto diffResultFnTy = resultFnTy->getWithExtInfo(
        resultFnTy->getExtInfo().withDifferentiabilityKind(
            DifferentiabilityKind::Normal));
    auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
If there is a common pattern here, maybe put the construction of this function onto a helper on SILOptFunctionBuilder? The hope with that was to hide the low level SILFunction creation functions in favor of higher level routines so we don't have to have huge invocations of SILFunctionType::get.
The code here takes a `$(Self) -> (T, ...) -> U` function (curry thunk) and clones it into a `$(Self) -> @differentiable (T, ...) -> U` function.
`SILOptFunctionBuilder` looks very minimal, so I'm not sure adding an entry point makes sense. We could add a type calculation helper to `SILFunctionType`, but that seems ad-hoc because this is the only user.
Let me know if you have further suggestions!
@dan-zheng I wrote that code, and the reason it is minimal now is that I was trying to create the scaffolding for what I just mentioned. No one has had the time to fix up old uses. That being said, we should not make the problem worse.
      }
    }

    SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
Is this really needed? Could we put this onto SILBuilder or something? Seems wrong to have such helper utilities.
I would give it a name that makes it explicit that you are using a tuple. Otherwise, how do I know I am getting back a tuple?
Addressing in #30970 (comment), let's continue discussion there.
      return builder.createTuple(loc, elements);
    }

    void extractAllElements(SILValue value, SILBuilder &builder,
This is `void SILBuilder::emitDestructureValueOperation(SILLocation loc, SILValue operand, SmallVectorImpl<SILValue> &result);`
Addressing in #30970 (comment), let's continue discussion there.
        results.push_back(builder.createTupleExtract(value.getLoc(), value, i));
    }

    void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
This again doesn't feel like a method that has anything to do with differentiation.
Currently, this code emits `AdditiveArithmetic.zero`, but we plan to change it to emit `Differentiable.zeroTangentVector` sometime (TF-1005), so I'd prefer to keep it differentiation-specific.
Callers of this function could be deduped though: `PullbackEmitter::emitZeroIndirect` and `JVPEmitter::emitZeroIndirect`. I can look into that, but I'm not sure how to best use `SILBuilderContext &` since a `SILBasicBlock::iterator` is needed as an insertion point: #30818 (comment)
Would moving all the contents of I would very much like to avoid moving everything into a single
Thanks @gottesmm! I'll create a PR addressing your feedback, we can continue discussion there.