Skip to content

Commit a2d0e3f

Browse files
committed
Address review comments.
- Add doc comments to `getAutoDiffFunctionLinkage`. - Optimize "curry level unwrapping" in `AnyFunctionType::getAutoDiffAssociatedFunctionType`.
1 parent f8f4e89 commit a2d0e3f

File tree

4 files changed

+17
-11
lines changed

4 files changed

+17
-11
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -572,16 +572,20 @@ getOffsetForAutoDiffAssociatedFunction(unsigned order,
572572
unsigned
573573
getNumAutoDiffAssociatedFunctions(unsigned differentiationOrder);
574574

575-
// Retrieve config from the function name of a variant of
576-
// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
577-
// Returns true if the function name is parsed successfully.
575+
/// Retrieve config from the function name of a variant of
576+
/// `Builtin.autodiffApply`, e.g. `Builtin.autodiffApply_jvp_arity2_order1`.
577+
/// Returns true if the function name is parsed successfully.
578578
bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
579579
AutoDiffAssociatedFunctionKind &kind,
580580
unsigned &arity, unsigned &order,
581581
bool &rethrows);
582582

583+
/// Computes the correct linkage for associated functions given the linkage of
584+
/// the original function. If the original linkage is not external and
585+
/// `isAssocFnExported` is true, use the original function's linkage. Otherwise,
586+
/// return hidden linkage.
583587
SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
584-
bool isExported);
588+
bool isAssocFnExported);
585589

586590
} // end namespace autodiff
587591

lib/AST/AutoDiff.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,22 @@ bool autodiff::getBuiltinAutoDiffApplyConfig(
9090
}
9191

9292
SILLinkage autodiff::getAutoDiffFunctionLinkage(SILLinkage originalLinkage,
93-
bool isExported) {
93+
bool isAssocFnExported) {
9494
// If the original is defined externally, then the AD pass is just generating
9595
// associated functions for use in the current module and therefore these
9696
// associated functions should not be visible outside the module.
9797
if (isAvailableExternally(originalLinkage))
9898
return SILLinkage::Hidden;
9999

100100
// If the original is public, then external modules may need to link the
101-
// associated function. Make the associated function public unless
102-
// differentiation is not explicitly requested.
101+
// associated function. Return the linkage of the original function, unless
102+
// the associated function is not exported (i.e. differentiation is not
103+
// explicitly requested via a `[differentiable]` attribute on the original
104+
// function).
103105
if (originalLinkage == SILLinkage::Public ||
104106
originalLinkage == SILLinkage::PublicNonABI ||
105107
originalLinkage == SILLinkage::Shared)
106-
return isExported ? originalLinkage : SILLinkage::Hidden;
108+
return isAssocFnExported ? originalLinkage : SILLinkage::Hidden;
107109

108110
// Otherwise, the original function is defined and used only in the current
109111
// module, so external modules will never try to access the associated

lib/AST/Type.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4463,7 +4463,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44634463
indices->getSubsetParameterTypes(this, wrtParamTypes);
44644464

44654465
// Unwrap curry levels. At most, two parameter lists are necessary, for
4466-
// curried method types.
4466+
// curried method types with a `(Self)` parameter list.
44674467
SmallVector<AnyFunctionType *, 2> curryLevels;
44684468
auto *currentLevel = eraseDynamicSelfType()->castTo<AnyFunctionType>();
44694469
for (unsigned i : range(2)) {
@@ -4576,7 +4576,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
45764576
AnyFunctionType *
45774577
AnyFunctionType::getAutoDiffOriginalFunctionType() {
45784578
// Unwrap curry levels. At most, two parameter lists are necessary, for
4579-
// curried method types.
4579+
// curried method types with a `(Self)` parameter list.
45804580
SmallVector<AnyFunctionType *, 2> curryLevels;
45814581
auto *currentLevel = this;
45824582
for (unsigned i : range(2)) {

lib/SILGen/SILGenPoly.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3531,7 +3531,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
35313531
auto loc = assocFn->getLocation();
35323532
SILGenFunctionBuilder fb(*this);
35333533
auto linkage = autodiff::getAutoDiffFunctionLinkage(
3534-
original->getLinkage(), /*isExported*/ true);
3534+
original->getLinkage(), /*isAssocFnExported*/ true);
35353535
auto *thunk = fb.getOrCreateFunction(
35363536
loc, name, linkage, targetType, IsBare, IsNotTransparent,
35373537
assocFn->isSerialized(), assocFn->isDynamicallyReplaceable(),

0 commit comments

Comments
 (0)