Skip to content

Commit 723b2d2

Browse files
authored
[AutoDiff upstream] Add derivative function witness/vtable entry SILGen. (#30569)
`@differentiable` attribute on protocol requirements and non-final class members now produces derivative function entries in witness tables and vtables. This enables `witness_method` and `class_method` differentiation. Existing type-checking rules: - Witness declarations of `@differentiable` protocol requirements must have a `@differentiable` attribute with the same configuration (or a configuration with superset parameter indices). - Witness table derivative function entries are SILGen'd for `@differentiable` witness declarations. - Class vtable derivative function entries are SILGen'd for non-final `@differentiable` class members. - These derivative entries can be overridden or inherited, just like other vtable entries. Resolves TF-1212.
1 parent e9a3c7f commit 723b2d2

17 files changed

+631
-87
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,8 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
104104
void Profile(llvm::FoldingSetNodeID &ID) {
105105
ID.AddInteger(kind);
106106
ID.AddPointer(parameterIndices);
107-
CanGenericSignature derivativeCanGenSig;
108-
if (derivativeGenericSignature)
109-
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
107+
auto derivativeCanGenSig =
108+
derivativeGenericSignature.getCanonicalSignature();
110109
ID.AddPointer(derivativeCanGenSig.getPointer());
111110
}
112111
};

include/swift/SIL/SILDeclRef.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,13 @@ struct SILDeclRef {
324324
return declRef;
325325
}
326326

327+
/// Returns this `SILDeclRef` replacing `loc` with `decl`.
328+
SILDeclRef withDecl(ValueDecl *decl) const {
329+
SILDeclRef result = *this;
330+
result.loc = decl;
331+
return result;
332+
}
333+
327334
/// True if the decl ref references a thunk from a natively foreign
328335
/// declaration to Swift calling convention.
329336
bool isForeignToNativeThunk() const;

include/swift/SIL/SILVTableVisitor.h

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,24 @@ template <class T> class SILVTableVisitor {
8686
void maybeAddMethod(FuncDecl *fd) {
8787
assert(!fd->hasClangNode());
8888

89-
maybeAddEntry(SILDeclRef(fd, SILDeclRef::Kind::Func));
89+
SILDeclRef constant(fd, SILDeclRef::Kind::Func);
90+
maybeAddEntry(constant);
91+
92+
for (auto *diffAttr : fd->getAttrs().getAttributes<DifferentiableAttr>()) {
93+
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
94+
AutoDiffDerivativeFunctionIdentifier::get(
95+
AutoDiffDerivativeFunctionKind::JVP,
96+
diffAttr->getParameterIndices(),
97+
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
98+
maybeAddEntry(jvpConstant);
99+
100+
auto vjpConstant = constant.asAutoDiffDerivativeFunction(
101+
AutoDiffDerivativeFunctionIdentifier::get(
102+
AutoDiffDerivativeFunctionKind::VJP,
103+
diffAttr->getParameterIndices(),
104+
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
105+
maybeAddEntry(vjpConstant);
106+
}
90107
}
91108

92109
void maybeAddConstructor(ConstructorDecl *cd) {
@@ -96,7 +113,24 @@ template <class T> class SILVTableVisitor {
96113
// The initializing entry point for designated initializers is only
97114
// necessary for super.init chaining, which is sufficiently constrained
98115
// to never need dynamic dispatch.
99-
maybeAddEntry(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
116+
SILDeclRef constant(cd, SILDeclRef::Kind::Allocator);
117+
maybeAddEntry(constant);
118+
119+
for (auto *diffAttr : cd->getAttrs().getAttributes<DifferentiableAttr>()) {
120+
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
121+
AutoDiffDerivativeFunctionIdentifier::get(
122+
AutoDiffDerivativeFunctionKind::JVP,
123+
diffAttr->getParameterIndices(),
124+
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
125+
maybeAddEntry(jvpConstant);
126+
127+
auto vjpConstant = constant.asAutoDiffDerivativeFunction(
128+
AutoDiffDerivativeFunctionIdentifier::get(
129+
AutoDiffDerivativeFunctionKind::VJP,
130+
diffAttr->getParameterIndices(),
131+
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
132+
maybeAddEntry(vjpConstant);
133+
}
100134
}
101135

102136
void maybeAddAccessors(AbstractStorageDecl *asd) {

include/swift/SIL/SILWitnessVisitor.h

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,19 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
122122

123123
void visitAbstractStorageDecl(AbstractStorageDecl *sd) {
124124
sd->visitOpaqueAccessors([&](AccessorDecl *accessor) {
125-
if (SILDeclRef::requiresNewWitnessTableEntry(accessor))
125+
if (SILDeclRef::requiresNewWitnessTableEntry(accessor)) {
126126
asDerived().addMethod(SILDeclRef(accessor, SILDeclRef::Kind::Func));
127+
addAutoDiffDerivativeMethodsIfRequired(accessor,
128+
SILDeclRef::Kind::Func);
129+
}
127130
});
128131
}
129132

130133
void visitConstructorDecl(ConstructorDecl *cd) {
131-
if (SILDeclRef::requiresNewWitnessTableEntry(cd))
134+
if (SILDeclRef::requiresNewWitnessTableEntry(cd)) {
132135
asDerived().addMethod(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
136+
addAutoDiffDerivativeMethodsIfRequired(cd, SILDeclRef::Kind::Allocator);
137+
}
133138
}
134139

135140
void visitAccessorDecl(AccessorDecl *func) {
@@ -138,8 +143,10 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
138143

139144
void visitFuncDecl(FuncDecl *func) {
140145
assert(!isa<AccessorDecl>(func));
141-
if (SILDeclRef::requiresNewWitnessTableEntry(func))
146+
if (SILDeclRef::requiresNewWitnessTableEntry(func)) {
142147
asDerived().addMethod(SILDeclRef(func, SILDeclRef::Kind::Func));
148+
addAutoDiffDerivativeMethodsIfRequired(func, SILDeclRef::Kind::Func);
149+
}
143150
}
144151

145152
void visitMissingMemberDecl(MissingMemberDecl *placeholder) {
@@ -166,6 +173,26 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
166173
void visitPoundDiagnosticDecl(PoundDiagnosticDecl *pdd) {
167174
// We don't care about diagnostics at this stage.
168175
}
176+
177+
private:
178+
void addAutoDiffDerivativeMethodsIfRequired(AbstractFunctionDecl *AFD,
179+
SILDeclRef::Kind kind) {
180+
SILDeclRef declRef(AFD, kind);
181+
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
182+
asDerived().addMethod(declRef.asAutoDiffDerivativeFunction(
183+
AutoDiffDerivativeFunctionIdentifier::get(
184+
AutoDiffDerivativeFunctionKind::JVP,
185+
diffAttr->getParameterIndices(),
186+
diffAttr->getDerivativeGenericSignature(),
187+
AFD->getASTContext())));
188+
asDerived().addMethod(declRef.asAutoDiffDerivativeFunction(
189+
AutoDiffDerivativeFunctionIdentifier::get(
190+
AutoDiffDerivativeFunctionKind::VJP,
191+
diffAttr->getParameterIndices(),
192+
diffAttr->getDerivativeGenericSignature(),
193+
AFD->getASTContext())));
194+
}
195+
}
169196
};
170197

171198
} // end namespace swift

lib/IRGen/GenDiffWitness.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,28 @@ void IRGenModule::emitSILDifferentiabilityWitness(
3939
ConstantInitBuilder builder(*this);
4040
auto diffWitnessContents = builder.beginStruct();
4141

42+
// TODO(TF-1211): Uncomment assertions after upstreaming differentiation
43+
// transform.
44+
// The mandatory differentiation transform canonicalizes differentiability
45+
// witnesses and ensures that JVPs/VJPs are populated.
46+
/*
4247
assert(dw->getJVP() &&
4348
"Differentiability witness definition should have JVP");
4449
assert(dw->getVJP() &&
4550
"Differentiability witness definition should have VJP");
46-
4751
diffWitnessContents.addBitCast(
4852
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
4953
diffWitnessContents.addBitCast(
5054
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
55+
*/
56+
llvm::Constant *jvpValue = llvm::UndefValue::get(Int8PtrTy);
57+
llvm::Constant *vjpValue = llvm::UndefValue::get(Int8PtrTy);
58+
if (auto *jvpFn = dw->getJVP())
59+
jvpValue = getAddrOfSILFunction(dw->getJVP(), NotForDefinition);
60+
if (auto *vjpFn = dw->getJVP())
61+
vjpValue = getAddrOfSILFunction(dw->getVJP(), NotForDefinition);
62+
diffWitnessContents.addBitCast(jvpValue, Int8PtrTy);
63+
diffWitnessContents.addBitCast(vjpValue, Int8PtrTy);
5164

5265
getAddrOfDifferentiabilityWitness(
5366
dw, diffWitnessContents.finishAndCreateFuture());

lib/IRGen/GenKeyPath.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -992,8 +992,7 @@ emitKeyPathComponent(IRGenModule &IGM,
992992
auto methodProto = cast<ProtocolDecl>(dc);
993993
auto &protoInfo = IGM.getProtocolInfo(methodProto,
994994
ProtocolInfoKind::Full);
995-
auto index = protoInfo.getFunctionIndex(
996-
cast<AbstractFunctionDecl>(declRef.getDecl()));
995+
auto index = protoInfo.getFunctionIndex(declRef);
997996
idValue = llvm::ConstantInt::get(IGM.SizeTy, -index.getValue());
998997
idResolution = KeyPathComponentHeader::Resolved;
999998
}

lib/IRGen/GenProto.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -792,20 +792,19 @@ namespace {
792792
}
793793

794794
void addMethod(SILDeclRef func) {
795-
auto decl = cast<AbstractFunctionDecl>(func.getDecl());
796795
// If this assert needs to be changed, be sure to also change
797796
// ProtocolDescriptorBuilder::getRequirementInfo.
798-
assert((isa<ConstructorDecl>(decl)
799-
? (func.kind == SILDeclRef::Kind::Allocator)
800-
: (func.kind == SILDeclRef::Kind::Func))
801-
&& "unexpected kind for protocol witness declaration ref");
802-
Entries.push_back(WitnessTableEntry::forFunction(decl));
797+
assert((isa<ConstructorDecl>(func.getDecl())
798+
? (func.kind == SILDeclRef::Kind::Allocator)
799+
: (func.kind == SILDeclRef::Kind::Func)) &&
800+
"unexpected kind for protocol witness declaration ref");
801+
Entries.push_back(WitnessTableEntry::forFunction(func));
803802
}
804803

805804
void addPlaceholder(MissingMemberDecl *placeholder) {
806805
for (auto i : range(placeholder->getNumberOfVTableEntries())) {
807806
(void)i;
808-
Entries.push_back(WitnessTableEntry());
807+
Entries.push_back(WitnessTableEntry::forPlaceholder());
809808
}
810809
}
811810

@@ -1318,8 +1317,7 @@ class AccessorConformanceInfo : public ConformanceInfo {
13181317
&& "sil witness table does not match protocol");
13191318
assert(entry.getMethodWitness().Requirement == requirement
13201319
&& "sil witness table does not match protocol");
1321-
auto piIndex =
1322-
PI.getFunctionIndex(cast<AbstractFunctionDecl>(requirement.getDecl()));
1320+
auto piIndex = PI.getFunctionIndex(requirement);
13231321
assert((size_t)piIndex.getValue() ==
13241322
Table.size() - WitnessTableFirstRequirementOffset &&
13251323
"offset doesn't match ProtocolInfo layout");
@@ -3277,7 +3275,7 @@ FunctionPointer irgen::emitWitnessMethodValue(IRGenFunction &IGF,
32773275

32783276
// Find the witness we're interested in.
32793277
auto &fnProtoInfo = IGF.IGM.getProtocolInfo(proto, ProtocolInfoKind::Full);
3280-
auto index = fnProtoInfo.getFunctionIndex(fn);
3278+
auto index = fnProtoInfo.getFunctionIndex(member);
32813279
llvm::Value *slot;
32823280
llvm::Value *witnessFnPtr =
32833281
emitInvariantLoadOfOpaqueWitness(IGF, wtable,

0 commit comments

Comments
 (0)