Skip to content

Commit 24445dd

Browse files
authored
[AutoDiff upstream] Add differentiability witness SILGen. (#30545)
Generate SIL differentiability witnesses from `@differentiable` and `@derivative` declaration attributes. Add SILGen utilities for: - Emitting differentiability witnesses. - Creating derivative function thunks, which are used as entries in differentiability witnesses. When users register a custom derivative function, it is necessary to create a thunk with the expected derivative type computed from the original function's type. This is important for consistent typing and consistent differentiability witness entry mangling. See `SILGenModule::getOrCreateCustomDerivativeThunk` documentation for details. Resolves TF-1138.
1 parent 7c5b4d1 commit 24445dd

File tree

9 files changed

+1166
-4
lines changed

9 files changed

+1166
-4
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,25 @@ class ASTMangler : public Mangler {
153153
Type SelfType,
154154
ModuleDecl *Module);
155155

156+
/// Mangle the derivative function (JVP/VJP) for the given:
157+
/// - Mangled original function name.
158+
/// - Derivative function kind.
159+
/// - Derivative function configuration: parameter/result indices and
160+
/// derivative generic signature.
161+
std::string
162+
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
163+
AutoDiffDerivativeFunctionKind kind,
164+
AutoDiffConfig config);
165+
166+
/// Mangle the linear map (differential/pullback) for the given:
167+
/// - Mangled original function name.
168+
/// - Linear map kind.
169+
/// - Derivative function configuration: parameter/result indices and
170+
/// derivative generic signature.
171+
std::string mangleAutoDiffLinearMapHelper(StringRef name,
172+
AutoDiffLinearMapKind kind,
173+
AutoDiffConfig config);
174+
156175
/// Mangle a SIL differentiability witness key:
157176
/// - Mangled original function name.
158177
/// - Parameter indices.

include/swift/AST/AutoDiff.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/TypeAlignments.h"
2727
#include "swift/Basic/Range.h"
2828
#include "swift/Basic/SourceLoc.h"
29+
#include "llvm/ADT/StringExtras.h"
2930

3031
namespace swift {
3132

@@ -95,6 +96,45 @@ struct DifferentiabilityWitnessFunctionKind {
9596
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
9697
};
9798

99+
/// SIL-level automatic differentiation indices. Consists of:
100+
/// - Parameter indices: indices of parameters to differentiate with respect to.
101+
/// - Result index: index of the result to differentiate from.
102+
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
103+
// `AutoDiffConfig` supports multiple result indices.
104+
struct SILAutoDiffIndices {
105+
/// The index of the dependent result to differentiate from.
106+
unsigned source;
107+
/// The indices for independent parameters to differentiate with respect to.
108+
IndexSubset *parameters;
109+
110+
/*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
111+
: source(source), parameters(parameters) {}
112+
113+
bool operator==(const SILAutoDiffIndices &other) const;
114+
115+
bool operator!=(const SILAutoDiffIndices &other) const {
116+
return !(*this == other);
117+
};
118+
119+
/// Returns true if `parameterIndex` is a differentiability parameter index.
120+
bool isWrtParameter(unsigned parameterIndex) const {
121+
return parameterIndex < parameters->getCapacity() &&
122+
parameters->contains(parameterIndex);
123+
}
124+
125+
void print(llvm::raw_ostream &s = llvm::outs()) const;
126+
SWIFT_DEBUG_DUMP;
127+
128+
std::string mangle() const {
129+
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
130+
interleave(
131+
parameters->getIndices(),
132+
[&](unsigned idx) { result += llvm::utostr(idx); },
133+
[&] { result += '_'; });
134+
return result;
135+
}
136+
};
137+
98138
/// Identifies an autodiff derivative function configuration:
99139
/// - Parameter indices.
100140
/// - Result indices.
@@ -110,6 +150,11 @@ struct AutoDiffConfig {
110150
: parameterIndices(parameterIndices), resultIndices(resultIndices),
111151
derivativeGenericSignature(derivativeGenericSignature) {}
112152

153+
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
154+
// TODO(TF-913): This is a temporary shim for incremental removal of
155+
// `SILAutoDiffIndices`. Eventually remove this.
156+
SILAutoDiffIndices getSILAutoDiffIndices() const;
157+
113158
void print(llvm::raw_ostream &s = llvm::outs()) const;
114159
SWIFT_DEBUG_DUMP;
115160
};
@@ -282,6 +327,37 @@ void getFunctionSemanticResultTypes(
282327
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
283328
GenericEnvironment *genericEnv = nullptr);
284329

330+
/// Returns the lowered SIL parameter indices for the given AST parameter
331+
/// indices and `AnyFunctionType`.
332+
///
333+
/// Notable lowering-related changes:
334+
/// - AST tuple parameter types are exploded when lowered to SIL.
335+
/// - AST curried `Self` parameter types become the last parameter when lowered
336+
/// to SIL.
337+
///
338+
/// Examples:
339+
///
340+
/// AST function type: (A, B, C) -> R
341+
/// AST parameter indices: 101, {A, C}
342+
/// Lowered SIL function type: $(A, B, C) -> R
343+
/// Lowered SIL parameter indices: 101
344+
///
345+
/// AST function type: (Self) -> (A, B, C) -> R
346+
/// AST parameter indices: 1010, {Self, B}
347+
/// Lowered SIL function type: $(A, B, C, Self) -> R
348+
/// Lowered SIL parameter indices: 0101
349+
///
350+
/// AST function type: (A, (B, C), D) -> R
351+
/// AST parameter indices: 110, {A, (B, C)}
352+
/// Lowered SIL function type: $(A, B, C, D) -> R
353+
/// Lowered SIL parameter indices: 1110
354+
///
355+
/// Note:
356+
/// - The AST function type must not be curried unless it is a method.
357+
/// Otherwise, the behavior is undefined.
358+
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
359+
AnyFunctionType *functionType);
360+
285361
/// "Constrained" derivative generic signatures require all differentiability
286362
/// parameters to conform to the `Differentiable` protocol.
287363
///

lib/AST/ASTMangler.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,57 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
373373
return finalize();
374374
}
375375

376+
std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
377+
StringRef name, AutoDiffDerivativeFunctionKind kind,
378+
AutoDiffConfig config) {
379+
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
380+
beginManglingWithoutPrefix();
381+
382+
Buffer << "AD__" << name << '_';
383+
switch (kind) {
384+
case AutoDiffDerivativeFunctionKind::JVP:
385+
Buffer << "_jvp_";
386+
break;
387+
case AutoDiffDerivativeFunctionKind::VJP:
388+
Buffer << "_vjp_";
389+
break;
390+
}
391+
Buffer << config.getSILAutoDiffIndices().mangle();
392+
if (config.derivativeGenericSignature) {
393+
Buffer << '_';
394+
appendGenericSignature(config.derivativeGenericSignature);
395+
}
396+
397+
auto result = Storage.str().str();
398+
Storage.clear();
399+
return result;
400+
}
401+
402+
std::string ASTMangler::mangleAutoDiffLinearMapHelper(
403+
StringRef name, AutoDiffLinearMapKind kind, AutoDiffConfig config) {
404+
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
405+
beginManglingWithoutPrefix();
406+
407+
Buffer << "AD__" << name << '_';
408+
switch (kind) {
409+
case AutoDiffLinearMapKind::Differential:
410+
Buffer << "_differential_";
411+
break;
412+
case AutoDiffLinearMapKind::Pullback:
413+
Buffer << "_pullback_";
414+
break;
415+
}
416+
Buffer << config.getSILAutoDiffIndices().mangle();
417+
if (config.derivativeGenericSignature) {
418+
Buffer << '_';
419+
appendGenericSignature(config.derivativeGenericSignature);
420+
}
421+
422+
auto result = Storage.str().str();
423+
Storage.clear();
424+
return result;
425+
}
426+
376427
std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
377428
SILDifferentiabilityWitnessKey key) {
378429
// TODO(TF-20): Make the mangling scheme robust. Support demangling.

lib/AST/AutoDiff.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,24 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
4141
}
4242
}
4343

44+
void SILAutoDiffIndices::print(llvm::raw_ostream &s) const {
45+
s << "(source=" << source << " parameters=(";
46+
interleave(
47+
parameters->getIndices(), [&s](unsigned p) { s << p; },
48+
[&s] { s << ' '; });
49+
s << "))";
50+
}
51+
52+
void SILAutoDiffIndices::dump() const {
53+
print(llvm::errs());
54+
llvm::errs() << '\n';
55+
}
56+
57+
SILAutoDiffIndices AutoDiffConfig::getSILAutoDiffIndices() const {
58+
assert(resultIndices->getNumIndices() == 1);
59+
return SILAutoDiffIndices(*resultIndices->begin(), parameterIndices);
60+
}
61+
4462
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
4563
s << "(parameters=";
4664
parameterIndices->print(s);
@@ -138,6 +156,42 @@ void autodiff::getFunctionSemanticResultTypes(
138156
}
139157
}
140158

159+
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
160+
IndexSubset *
161+
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,
162+
AnyFunctionType *functionType) {
163+
SmallVector<AnyFunctionType *, 2> curryLevels;
164+
unwrapCurryLevels(functionType, curryLevels);
165+
166+
// Compute the lowered sizes of all AST parameter types.
167+
SmallVector<unsigned, 8> paramLoweredSizes;
168+
unsigned totalLoweredSize = 0;
169+
auto addLoweredParamInfo = [&](Type type) {
170+
unsigned paramLoweredSize = countNumFlattenedElementTypes(type);
171+
paramLoweredSizes.push_back(paramLoweredSize);
172+
totalLoweredSize += paramLoweredSize;
173+
};
174+
for (auto *curryLevel : llvm::reverse(curryLevels))
175+
for (auto &param : curryLevel->getParams())
176+
addLoweredParamInfo(param.getPlainType());
177+
178+
// Build lowered SIL parameter indices by setting the range of bits that
179+
// corresponds to each "set" AST parameter.
180+
llvm::SmallVector<unsigned, 8> loweredSILIndices;
181+
unsigned currentBitIndex = 0;
182+
for (unsigned i : range(parameterIndices->getCapacity())) {
183+
auto paramLoweredSize = paramLoweredSizes[i];
184+
if (parameterIndices->contains(i)) {
185+
auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize);
186+
loweredSILIndices.append(indices.begin(), indices.end());
187+
}
188+
currentBitIndex += paramLoweredSize;
189+
}
190+
191+
return IndexSubset::get(functionType->getASTContext(), totalLoweredSize,
192+
loweredSILIndices);
193+
}
194+
141195
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
142196
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
143197
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,

lib/SILGen/SILGen.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,132 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
751751
LLVM_DEBUG(llvm::dbgs() << "lowered sil:\n";
752752
F->print(llvm::dbgs()));
753753
F->verify();
754+
755+
emitDifferentiabilityWitnessesForFunction(constant, F);
756+
}
757+
758+
void SILGenModule::emitDifferentiabilityWitnessesForFunction(
759+
SILDeclRef constant, SILFunction *F) {
760+
// Visit `@differentiable` and `@derivative` attributes and generate SIL
761+
// differentiability witnesses.
762+
// Skip if the SILDeclRef is a:
763+
// - Default argument generator function.
764+
// - Thunk.
765+
if (!constant.hasDecl() || !constant.getAbstractFunctionDecl())
766+
return;
767+
if (constant.kind == SILDeclRef::Kind::DefaultArgGenerator ||
768+
constant.isThunk())
769+
return;
770+
auto *AFD = constant.getAbstractFunctionDecl();
771+
auto emitWitnesses = [&](DeclAttributes &Attrs) {
772+
for (auto *diffAttr : Attrs.getAttributes<DifferentiableAttr>()) {
773+
SILFunction *jvp = nullptr;
774+
SILFunction *vjp = nullptr;
775+
if (auto *jvpDecl = diffAttr->getJVPFunction())
776+
jvp = getFunction(SILDeclRef(jvpDecl), ForDefinition);
777+
if (auto *vjpDecl = diffAttr->getVJPFunction())
778+
vjp = getFunction(SILDeclRef(vjpDecl), ForDefinition);
779+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
780+
assert((!F->getLoweredFunctionType()->getSubstGenericSignature() ||
781+
diffAttr->getDerivativeGenericSignature()) &&
782+
"Type-checking should resolve derivative generic signatures for "
783+
"all original SIL functions with generic signatures");
784+
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
785+
diffAttr->getDerivativeGenericSignature());
786+
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
787+
}
788+
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
789+
SILFunction *jvp = nullptr;
790+
SILFunction *vjp = nullptr;
791+
switch (derivAttr->getDerivativeKind()) {
792+
case AutoDiffDerivativeFunctionKind::JVP:
793+
jvp = F;
794+
break;
795+
case AutoDiffDerivativeFunctionKind::VJP:
796+
vjp = F;
797+
break;
798+
}
799+
auto *origAFD = derivAttr->getOriginalFunction();
800+
auto origDeclRef =
801+
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
802+
auto *origFn = getFunction(origDeclRef, NotForDefinition);
803+
auto derivativeGenSig = AFD->getGenericSignature();
804+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
805+
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
806+
derivativeGenSig);
807+
emitDifferentiabilityWitness(origAFD, origFn, config, jvp, vjp,
808+
derivAttr);
809+
}
810+
};
811+
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
812+
if (accessor->isGetter())
813+
emitWitnesses(accessor->getStorage()->getAttrs());
814+
emitWitnesses(AFD->getAttrs());
815+
}
816+
817+
void SILGenModule::emitDifferentiabilityWitness(
818+
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
819+
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp,
820+
const DeclAttribute *attr) {
821+
assert(isa<DifferentiableAttr>(attr) || isa<DerivativeAttr>(attr));
822+
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
823+
auto origSilFnType = originalFunction->getLoweredFunctionType();
824+
auto *silParamIndices =
825+
autodiff::getLoweredParameterIndices(config.parameterIndices, origFnType);
826+
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
827+
// parameters corresponding to captured variables. These parameters do not
828+
// appear in the type of `origFnType`.
829+
// TODO: If possible, change `autodiff::getLoweredParameterIndices` to
830+
// take `CaptureInfo` into account.
831+
if (origSilFnType->getNumParameters() > silParamIndices->getCapacity())
832+
silParamIndices = silParamIndices->extendingCapacity(
833+
getASTContext(), origSilFnType->getNumParameters());
834+
835+
// Get or create new SIL differentiability witness.
836+
// Witness already exists when there are two `@derivative` attributes
837+
// (registering JVP and VJP functions) for the same derivative function
838+
// configuration.
839+
// Witness JVP and VJP are set below.
840+
AutoDiffConfig silConfig(silParamIndices, config.resultIndices,
841+
config.derivativeGenericSignature);
842+
SILDifferentiabilityWitnessKey key{originalFunction->getName(), silConfig};
843+
auto *diffWitness = M.lookUpDifferentiabilityWitness(key);
844+
if (!diffWitness) {
845+
// Strip external from linkage of original function.
846+
// Necessary for Clang-imported functions, which have external linkage.
847+
auto linkage = stripExternalFromLinkage(originalFunction->getLinkage());
848+
diffWitness = SILDifferentiabilityWitness::createDefinition(
849+
M, linkage, originalFunction, silConfig.parameterIndices,
850+
silConfig.resultIndices, config.derivativeGenericSignature,
851+
/*jvp*/ nullptr, /*vjp*/ nullptr,
852+
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
853+
attr);
854+
}
855+
856+
// Set derivative function in differentiability witness.
857+
auto setDerivativeInDifferentiabilityWitness =
858+
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
859+
auto derivativeThunk = getOrCreateCustomDerivativeThunk(
860+
derivative, originalFunction, silConfig, kind);
861+
// Check for existing same derivative.
862+
// TODO(TF-835): Remove condition below and simplify assertion to
863+
// `!diffWitness->getDerivative(kind)` after `@derivative` attribute
864+
// type-checking no longer generates implicit `@differentiable`
865+
// attributes.
866+
auto *existingDerivative = diffWitness->getDerivative(kind);
867+
if (existingDerivative && existingDerivative == derivativeThunk)
868+
return;
869+
assert(!existingDerivative &&
870+
"SIL differentiability witness already has a different existing "
871+
"derivative");
872+
diffWitness->setDerivative(kind, derivativeThunk);
873+
};
874+
if (jvp)
875+
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::JVP,
876+
jvp);
877+
if (vjp)
878+
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::VJP,
879+
vjp);
754880
}
755881

756882
void SILGenModule::

0 commit comments

Comments
 (0)