Skip to content

Commit 0a9181a

Browse files
authored
[AutoDiff] Revamp differentiation transform. (#24845)
- Directly generate primal code in VJP functions. - Rename `PrimalGenCloner` to `VJPEmitter`. - This unblocks control-flow support. Primal data structures can be generated based on activity analysis. - Refactor differentiation transform to use iterative `autodiff_function` instruction worklist instead of a fixed `DifferentiationTask` worklist. - `VJPEmitter::visitApply` now creates `autodiff_function` and `autodiff_function_extract` instructions instead of directly emitting associated function references. - An iterative loop canonicalizes `autodiff_function` instructions, promoting them to `@differentiable` function-typed values. - Add `autodiff_function_extract` folding optimization. - Fold `autodiff_function_extract` users of `autodiff_function` instructions, directly replacing them with operands of the `autodiff_function` instruction. - If the `autodiff_function` instruction has only `autodiff_function_extract` users, delete the instruction itself after folding. - Remove unnecessary auxiliary data structures. - `DifferentiationTask`: replace with map from `[differentiable]` attributes to `DifferentiationInvoker`. - `PrimalGen`, `AdjointGen`: replace with `autodiff_function` iterative loop.
1 parent d75a649 commit 0a9181a

15 files changed

+1838
-1554
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,12 +578,12 @@ bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
578578
unsigned &arity, unsigned &order,
579579
bool &rethrows);
580580

581-
/// Computes the correct linkage for associated functions given the linkage of
581+
/// Computes the correct linkage for an associated function given the linkage of
582582
/// the original function. If the original linkage is not external and
583583
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
584584
/// return hidden linkage.
585-
SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
586-
bool isAssocFnExported);
585+
SILLinkage getAutoDiffAssociatedFunctionLinkage(SILLinkage originalLinkage,
586+
bool isAssocFnExported);
587587

588588
} // end namespace autodiff
589589

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4153,7 +4153,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41534153
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
41544154
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
41554155
SILModule &module, LookupConformanceFn lookupConformance,
4156-
GenericSignature *whereClauseGenericSignature = nullptr);
4156+
CanGenericSignature whereClauseGenericSignature = nullptr);
41574157

41584158
/// Returns a bit vector that specifices which parameters you can
41594159
/// differentiate with respect to for this differentiable function type. (e.g.

lib/AST/AutoDiff.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
8989
return operationName.empty();
9090
}
9191

92-
SILLinkage autodiff::getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
93-
bool isAssocFnExported) {
92+
SILLinkage autodiff::getAutoDiffAssociatedFunctionLinkage(
93+
SILLinkage originalLinkage, bool isAssocFnExported) {
9494
// If the original is defined externally, then the AD pass is just generating
9595
// associated functions for use in the current module and therefore these
9696
// associated functions should not be visible outside the module.

lib/SIL/SILFunctionType.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,18 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
150150
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
151151
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
152152
SILModule &module, LookupConformanceFn lookupConformance,
153-
GenericSignature *whereClauseGenSig) {
153+
CanGenericSignature whereClauseGenSig) {
154154
// JVP: (T...) -> ((R...),
155155
// (T.TangentVector...) -> (R.TangentVector...))
156156
// VJP: (T...) -> ((R...),
157157
// (R.TangentVector...) -> (T.TangentVector...))
158158

159159
auto &ctx = getASTContext();
160160
auto &typeConverter = module.Types;
161-
Lowering::GenericContextScope
162-
genericContextScope(module.Types, getGenericSignature());
163161
if (!whereClauseGenSig)
164162
whereClauseGenSig = getGenericSignature();
163+
Lowering::GenericContextScope genericContextScope(
164+
module.Types, whereClauseGenSig);
165165

166166
// Given a type, returns its formal SIL parameter info.
167167
auto getTangentParameterInfoForOriginalResult = [&](

lib/SILGen/SILGenPoly.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3530,7 +3530,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
35303530

35313531
auto loc = assocFn->getLocation();
35323532
SILGenFunctionBuilder fb(*this);
3533-
auto linkage = autodiff::getAutoDiffFunctionLinkage(
3533+
auto linkage = autodiff::getAutoDiffAssociatedFunctionLinkage(
35343534
original->getLinkage(), /*isAssocFnExported*/ true);
35353535
auto *thunk = fb.getOrCreateFunction(
35363536
loc, name, linkage, targetType, IsBare, IsNotTransparent,

lib/SILOptimizer/IPO/CapturePropagation.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,13 +461,20 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
461461
SILOptFunctionBuilder FuncBuilder(*this);
462462
if (auto *NewFunc = getSpecializedWithDeadParams(FuncBuilder,
463463
PAI, SubstF, PAI->getNumArguments(), GenericSpecialized)) {
464-
rewritePartialApply(PAI, NewFunc);
465-
if (GenericSpecialized.first) {
466-
// Notify the pass manager about the new function.
467-
addFunctionToPassManagerWorklist(GenericSpecialized.first,
468-
GenericSpecialized.second);
464+
// SWIFT_ENABLE_TENSORFLOW
465+
// Add a previously unexercised check to prevent AD crash. Rewrite
466+
// `partial_apply` only if the specialized function is `@convention(thin)`.
467+
// Revert check when `VJPEmitter::visitApplyInst` no longer produces
468+
// argumentless `partial_apply` instructions.
469+
if (NewFunc->getRepresentation() == SILFunctionTypeRepresentation::Thin) {
470+
rewritePartialApply(PAI, NewFunc);
471+
if (GenericSpecialized.first) {
472+
// Notify the pass manager about the new function.
473+
addFunctionToPassManagerWorklist(GenericSpecialized.first,
474+
GenericSpecialized.second);
475+
}
476+
return true;
469477
}
470-
return true;
471478
}
472479

473480
// Second possibility: Are all partially applied arguments constant?

0 commit comments

Comments
 (0)