Skip to content

Commit 4d41c78

Browse files
committed
[AutoDiff] SILGen differentiability witnesses.
Generate SIL differentiability witnesses from AST `@differentiable` and `@differentiating` attributes.
1 parent 811a6bd commit 4d41c78

File tree

9 files changed

+243
-74
lines changed

9 files changed

+243
-74
lines changed

include/swift/AST/Attr.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,8 @@ class DifferentiatingAttr final
16821682
unsigned NumParsedParameters = 0;
16831683
/// The differentiation parameters' indices, resolved by the type checker.
16841684
IndexSubset *ParameterIndices = nullptr;
1685+
/// The derivative function kind (JVP or VJP), resolved by the type checker.
1686+
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
16851687

16861688
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16871689
SourceLoc atLoc, SourceRange baseRange,
@@ -1711,6 +1713,12 @@ class DifferentiatingAttr final
17111713
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
17121714
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
17131715

1716+
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
1717+
assert(Kind && "Derivative function kind has not yet been resolved");
1718+
return *Kind;
1719+
}
1720+
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
1721+
17141722
/// The parsed differentiation parameters, i.e. the list of parameters
17151723
/// specified in 'wrt:'.
17161724
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {

include/swift/AST/AutoDiff.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,27 @@ template<> struct DenseMapInfo<AutoDiffConfig> {
413413
}
414414
};
415415

416+
template<> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
417+
static AutoDiffDerivativeFunctionKind getEmptyKey() {
418+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
419+
DenseMapInfo<unsigned>::getEmptyKey());
420+
}
421+
422+
static AutoDiffDerivativeFunctionKind getTombstoneKey() {
423+
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(
424+
DenseMapInfo<unsigned>::getTombstoneKey());
425+
}
426+
427+
static unsigned getHashValue(const AutoDiffDerivativeFunctionKind &Val) {
428+
return DenseMapInfo<unsigned>::getHashValue(Val);
429+
}
430+
431+
static bool isEqual(const AutoDiffDerivativeFunctionKind &LHS,
432+
const AutoDiffDerivativeFunctionKind &RHS) {
433+
return LHS == RHS;
434+
}
435+
};
436+
416437
template<> struct DenseMapInfo<SILAutoDiffIndices> {
417438
static SILAutoDiffIndices getEmptyKey() {
418439
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };

include/swift/SIL/SILModule.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,13 @@ class SILModule {
609609
/// hierarchy of \p Class.
610610
SILFunction *lookUpFunctionInVTable(ClassDecl *Class, SILDeclRef Member);
611611

612+
// SWIFT_ENABLE_TENSORFLOW
613+
/// Look up the differentiability witness corresponding to the given key.
614+
SILDifferentiabilityWitness *
615+
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key,
616+
bool deserializeLazily=true);
617+
// SWIFT_ENABLE_TENSORFLOW END
618+
612619
// Given a protocol, attempt to create a default witness table declaration
613620
// for it.
614621
SILDefaultWitnessTable *

lib/SIL/SILModule.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,18 @@ SILModule::lookUpDefaultWitnessTable(const ProtocolDecl *Protocol,
251251
return found->second;
252252
}
253253

254+
// SWIFT_ENABLE_TENSORFLOW
255+
SILDifferentiabilityWitness *
256+
SILModule::lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key,
257+
bool deserializeLazily) {
258+
auto found = DifferentiabilityWitnessMap.find(key);
259+
if (found != DifferentiabilityWitnessMap.end())
260+
return found->second;
261+
if (deserializeLazily)
262+
return getSILLoader()->lookupDifferentiabilityWitness(key);
263+
return nullptr;
264+
}
265+
254266
SILDefaultWitnessTable *
255267
SILModule::createDefaultWitnessTableDeclaration(const ProtocolDecl *Protocol,
256268
SILLinkage Linkage) {

lib/SILGen/SILGen.cpp

Lines changed: 127 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -752,87 +752,140 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
752752
F->print(llvm::dbgs()));
753753

754754
// SWIFT_ENABLE_TENSORFLOW
755-
// Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods.
756-
if (constant.hasDecl() && constant.getAbstractFunctionDecl()) {
755+
// Visit `@differentiable` and `@differentiating` attributes and generate SIL
756+
// differentiability witnesses.
757+
// Do not visit default argument generator functions.
758+
if (constant.hasDecl() && constant.getAbstractFunctionDecl() &&
759+
constant.kind != SILDeclRef::Kind::DefaultArgGenerator) {
757760
auto *AFD = constant.getAbstractFunctionDecl();
758-
auto origFnType = AFD->getInterfaceType()->castTo<AnyFunctionType>();
759-
auto origSilFnType = F->getLoweredFunctionType();
760-
// Jointly iterate over AST `@differentiable` attributes and SIL
761-
// `[differentiable]` attributes.
762-
auto diffAttrs = AFD->getAttrs().getAttributes<DifferentiableAttr>();
763-
auto silDiffAttrs = F->getDifferentiableAttrs();
764-
for (auto pair : llvm::zip(diffAttrs, silDiffAttrs)) {
765-
auto *diffAttr = const_cast<DifferentiableAttr *>(std::get<0>(pair));
766-
auto *silDiffAttr = std::get<1>(pair);
767-
// Compute lowered parameter indices.
768-
auto *paramIndices = diffAttr->getParameterIndices();
769-
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
770-
paramIndices, origFnType);
771-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
772-
assert(silDiffAttr->getIndices() == indices &&
773-
"Expected matching @differentiable and [differentiable] indices");
774-
775-
auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule());
776-
auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType(
777-
indices.parameters, indices.source,
778-
AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance);
779-
auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType(
780-
indices.parameters, indices.source,
781-
AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance);
782-
783-
// Self reordering is necessary if wrt at least two parameters, including
784-
// self.
785-
auto shouldReorderSelf = [&]() {
786-
if (!F->hasSelfParam())
787-
return false;
788-
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
789-
if (!indices.isWrtParameter(selfParamIndex))
790-
return false;
791-
return indices.parameters->getNumIndices() > 1;
792-
};
793-
bool reorderSelf = shouldReorderSelf();
794-
795-
// Thunk JVP method, if it is defined.
796-
if (auto *jvpDecl = diffAttr->getJVPFunction()) {
797-
SILFunction *jvpThunk;
798-
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
799-
if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) {
800-
jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
801-
F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP,
802-
reorderSelf);
803-
} else {
804-
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
805-
AutoDiffDerivativeFunctionKind::JVP,
806-
diffAttr->getParameterIndices(), AFD->getASTContext());
807-
jvpThunk = getOrCreateAutoDiffThunk(
808-
constant.asAutoDiffDerivativeFunction(id), jvpFn,
809-
expectedJVPType);
810-
}
811-
silDiffAttr->setJVPName(jvpThunk->getName());
812-
}
813-
// Thunk VJP method, if it is defined.
814-
if (auto *vjpDecl = diffAttr->getVJPFunction()) {
815-
SILFunction *vjpThunk;
816-
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
817-
if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) {
818-
vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
819-
F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP,
820-
reorderSelf);
821-
} else {
822-
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
823-
AutoDiffDerivativeFunctionKind::VJP,
824-
diffAttr->getParameterIndices(), AFD->getASTContext());
825-
vjpThunk = getOrCreateAutoDiffThunk(
826-
constant.asAutoDiffDerivativeFunction(id), vjpFn,
827-
expectedVJPType);
828-
}
829-
silDiffAttr->setVJPName(vjpThunk->getName());
761+
// Visit all `@differentiable` attributes.
762+
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
763+
SILFunction *jvp = nullptr;
764+
SILFunction *vjp = nullptr;
765+
if (auto *jvpDecl = diffAttr->getJVPFunction())
766+
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
767+
if (auto *vjpDecl = diffAttr->getVJPFunction())
768+
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
769+
emitDifferentiabilityWitness(AFD, F, diffAttr->getParameterIndices(), jvp,
770+
vjp);
771+
}
772+
// Visit all `@differentiating` attributes.
773+
for (auto *diffAttr :
774+
AFD->getAttrs().getAttributes<DifferentiatingAttr>()) {
775+
auto *origAFD = diffAttr->getOriginalFunction();
776+
auto *origFn = getFunction(SILDeclRef(origAFD), NotForDefinition);
777+
SILFunction *jvp = nullptr;
778+
SILFunction *vjp = nullptr;
779+
switch (diffAttr->getDerivativeKind()) {
780+
case AutoDiffDerivativeFunctionKind::JVP:
781+
jvp = F;
782+
break;
783+
case AutoDiffDerivativeFunctionKind::VJP:
784+
vjp = F;
785+
break;
830786
}
787+
emitDifferentiabilityWitness(origAFD, origFn,
788+
diffAttr->getParameterIndices(), jvp, vjp);
831789
}
832790
}
833791
F->verify();
834792
}
835793

794+
void SILGenModule::emitDifferentiabilityWitness(
795+
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
796+
IndexSubset *parameterIndices, SILFunction *jvp, SILFunction *vjp) {
797+
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
798+
auto origSilFnType = originalFunction->getLoweredFunctionType();
799+
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
800+
parameterIndices, origFnType);
801+
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
802+
// parameters corresponding to captured variables. These parameters do not
803+
// appear in the type of `origFnType`.
804+
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
805+
// take `CaptureInfo` into account.
806+
if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity())
807+
loweredParamIndices = loweredParamIndices->extendingCapacity(
808+
getASTContext(), origSilFnType->getNumParameters());
809+
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
810+
811+
// Self reordering thunk is necessary if wrt at least two parameters,
812+
// including self.
813+
auto shouldReorderSelf = [&]() {
814+
if (!originalFunction->hasSelfParam())
815+
return false;
816+
auto selfParamIndex = origSilFnType->getNumParameters() - 1;
817+
if (!indices.isWrtParameter(selfParamIndex))
818+
return false;
819+
return indices.parameters->getNumIndices() > 1;
820+
};
821+
bool reorderSelf = shouldReorderSelf();
822+
823+
// Get or create differentiability witness.
824+
CanGenericSignature derivativeGenSig;
825+
if (jvp && vjp)
826+
assert(jvp->getLoweredFunctionType()->getGenericSignature() ==
827+
vjp->getLoweredFunctionType()->getGenericSignature() &&
828+
"JVP and VJP generic signatures must match");
829+
if (jvp)
830+
derivativeGenSig = jvp->getLoweredFunctionType()->getGenericSignature();
831+
if (vjp)
832+
derivativeGenSig = vjp->getLoweredFunctionType()->getGenericSignature();
833+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
834+
AutoDiffConfig config{loweredParamIndices, resultIndices,
835+
derivativeGenSig};
836+
auto key = std::make_pair(originalFunction->getName(), config);
837+
SILDifferentiabilityWitness *diffWitness = nullptr;
838+
if (auto *foundWitness = M.lookUpDifferentiabilityWitness(
839+
key, /*deserializeLazily*/ false)) {
840+
diffWitness = foundWitness;
841+
} else {
842+
// Create new SIL differentiability witness.
843+
diffWitness = SILDifferentiabilityWitness::create(
844+
M, originalFunction->getLinkage(), originalFunction,
845+
loweredParamIndices, resultIndices, derivativeGenSig, /*jvp*/ nullptr,
846+
/*vjp*/ nullptr, /*isSerialized*/ true);
847+
}
848+
849+
// Set derivative function in differentiability witness.
850+
auto setDerivativeInDifferentiabilityWitness =
851+
[&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) {
852+
auto expectedDerivativeType =
853+
origSilFnType->getAutoDiffDerivativeFunctionType(
854+
indices.parameters, indices.source, kind, Types,
855+
LookUpConformanceInModule(M.getSwiftModule()));
856+
// Thunk derivative function.
857+
SILFunction *derivativeThunk;
858+
if (reorderSelf ||
859+
derivative->getLoweredFunctionType() != expectedDerivativeType) {
860+
derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
861+
originalFunction, indices, derivative, kind, reorderSelf);
862+
} else {
863+
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
864+
kind, parameterIndices, getASTContext());
865+
derivativeThunk = getOrCreateAutoDiffThunk(
866+
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
867+
expectedDerivativeType);
868+
}
869+
// Check for existing same derivative.
870+
// TODO(TF-898): Remove condition below and simplify assertion to
871+
// `!diffWitness->getDerivative(kind)` after `@differentiating` attribute
872+
// type-checking no longer generates implicit `@differentiable` attributes.
873+
auto *existingDerivative = diffWitness->getDerivative(kind);
874+
if (existingDerivative && existingDerivative == derivativeThunk)
875+
return;
876+
assert(!existingDerivative &&
877+
"SIL differentiability witness already has a different existing "
878+
"derivative");
879+
diffWitness->setDerivative(kind, derivativeThunk);
880+
};
881+
if (jvp)
882+
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::JVP,
883+
jvp);
884+
if (vjp)
885+
setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::VJP,
886+
vjp);
887+
}
888+
836889
void SILGenModule::
837890
emitMarkFunctionEscapeForTopLevelCodeGlobals(SILLocation loc,
838891
const CaptureInfo &captureInfo) {

lib/SILGen/SILGen.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
318318
/// Emit the self-conformance witness table for a protocol.
319319
void emitSelfConformanceWitnessTable(ProtocolDecl *protocol);
320320

321+
// SWIFT_ENABLE_TENSORFLOW
322+
/// Emit the differentiability witness for the given original function
323+
/// declaration and SIL function, parameter indices, and JVP and VJP
324+
/// functions (null if undefined).
325+
void emitDifferentiabilityWitness(AbstractFunctionDecl *originalAFD,
326+
SILFunction *originalFunction,
327+
IndexSubset *parameterIndices,
328+
SILFunction *jvp, SILFunction *vjp);
329+
// SWIFT_ENABLE_TENSORFLOW END
330+
321331
/// Emit the lazy initializer function for a global pattern binding
322332
/// declaration.
323333
SILFunction *emitLazyGlobalInitializer(StringRef funcName,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3610,6 +3610,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
36103610
attr->setInvalid();
36113611
return;
36123612
}
3613+
attr->setDerivativeKind(kind);
36133614
// `value: R` result tuple element must conform to `Differentiable`.
36143615
auto diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
36153616
auto valueResultType = valueResultElt.getType();

lib/Serialization/DeserializeSIL.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2972,6 +2972,9 @@ void SILDeserializer::readWitnessTableEntries(
29722972
// Another record means the end of this WitnessTable.
29732973
while (kind != SIL_WITNESS_TABLE &&
29742974
kind != SIL_DEFAULT_WITNESS_TABLE &&
2975+
// SWIFT_ENABLE_TENSORFLOW
2976+
kind != SIL_DIFFERENTIABILITY_WITNESS &&
2977+
// SWIFT_ENABLE_TENSORFLOW END
29752978
kind != SIL_FUNCTION) {
29762979
if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) {
29772980
witnessEntries.push_back(SILDefaultWitnessTable::Entry());
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %target-swift-frontend -emit-silgen %s | %target-sil-opt | %FileCheck %s
2+
3+
// Test SIL differentiability witness SIL generation.
4+
5+
// Test public non-generic function.
6+
// SIL differentiability witness:
7+
// - Has public linkage (implicit).
8+
// - Has no `where` clause.
9+
10+
public func foo(_ x: Float) -> Float { x }
11+
12+
@differentiating(foo)
13+
public func foo_jvp(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
14+
(x, { $0 })
15+
}
16+
17+
@differentiating(foo)
18+
public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
19+
(x, { $0 })
20+
}
21+
22+
// CHECK-LABEL: // differentiability witness for foo(_:)
23+
// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3fooyS2fF : $@convention(thin) (Float) -> Float {
24+
// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
25+
// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen3fooyS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
26+
// CHECK-NEXT: }
27+
28+
// Test internal generic function.
29+
// SIL differentiability witness:
30+
// - Has hidden linkage.
31+
// - Has `where` clause.
32+
33+
@differentiable(where T: Differentiable)
34+
func generic<T>(_ x: T, _ y: Float) -> T { x }
35+
36+
@differentiating(generic)
37+
func generic_jvp<T: Differentiable>(_ x: T, _ y: Float) -> (
38+
value: T, differential: (T.TangentVector, Float) -> T.TangentVector
39+
) {
40+
(x, { dx, dy in dx })
41+
}
42+
43+
@differentiating(generic)
44+
func generic_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (
45+
value: T, pullback: (T.TangentVector) -> (T.TangentVector, Float)
46+
) {
47+
(x, { ($0, .zero) })
48+
}
49+
50+
// CHECK-LABEL: // differentiability witness for generic<A>(_:_:)
51+
// CHECK-NEXT: sil_differentiability_witness hidden [serialized] [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @$s36sil_differentiability_witness_silgen7genericyxx_SftlF : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
52+
// CHECK-NEXT: jvp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
53+
// CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen7genericyxx_SftlF__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
54+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)