Skip to content

[AutoDiff upstream] Add derivative function witness/vtable entry SILGen. #30569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
void Profile(llvm::FoldingSetNodeID &ID) {
ID.AddInteger(kind);
ID.AddPointer(parameterIndices);
CanGenericSignature derivativeCanGenSig;
if (derivativeGenericSignature)
derivativeCanGenSig = derivativeGenericSignature->getCanonicalSignature();
auto derivativeCanGenSig =
derivativeGenericSignature.getCanonicalSignature();
ID.AddPointer(derivativeCanGenSig.getPointer());
}
};
Expand Down
7 changes: 7 additions & 0 deletions include/swift/SIL/SILDeclRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,13 @@ struct SILDeclRef {
return declRef;
}

/// Returns this `SILDeclRef` replacing `loc` with `decl`.
SILDeclRef withDecl(ValueDecl *decl) const {
SILDeclRef result = *this;
result.loc = decl;
return result;
}

/// True if the decl ref references a thunk from a natively foreign
/// declaration to Swift calling convention.
bool isForeignToNativeThunk() const;
Expand Down
38 changes: 36 additions & 2 deletions include/swift/SIL/SILVTableVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,24 @@ template <class T> class SILVTableVisitor {
void maybeAddMethod(FuncDecl *fd) {
assert(!fd->hasClangNode());

maybeAddEntry(SILDeclRef(fd, SILDeclRef::Kind::Func));
SILDeclRef constant(fd, SILDeclRef::Kind::Func);
maybeAddEntry(constant);

for (auto *diffAttr : fd->getAttrs().getAttributes<DifferentiableAttr>()) {
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
maybeAddEntry(jvpConstant);

auto vjpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
maybeAddEntry(vjpConstant);
}
}

void maybeAddConstructor(ConstructorDecl *cd) {
Expand All @@ -96,7 +113,24 @@ template <class T> class SILVTableVisitor {
// The initializing entry point for designated initializers is only
// necessary for super.init chaining, which is sufficiently constrained
// to never need dynamic dispatch.
maybeAddEntry(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
SILDeclRef constant(cd, SILDeclRef::Kind::Allocator);
maybeAddEntry(constant);

for (auto *diffAttr : cd->getAttrs().getAttributes<DifferentiableAttr>()) {
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
maybeAddEntry(jvpConstant);

auto vjpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
maybeAddEntry(vjpConstant);
}
}

void maybeAddAccessors(AbstractStorageDecl *asd) {
Expand Down
33 changes: 30 additions & 3 deletions include/swift/SIL/SILWitnessVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,19 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {

void visitAbstractStorageDecl(AbstractStorageDecl *sd) {
sd->visitOpaqueAccessors([&](AccessorDecl *accessor) {
if (SILDeclRef::requiresNewWitnessTableEntry(accessor))
if (SILDeclRef::requiresNewWitnessTableEntry(accessor)) {
asDerived().addMethod(SILDeclRef(accessor, SILDeclRef::Kind::Func));
addAutoDiffDerivativeMethodsIfRequired(accessor,
SILDeclRef::Kind::Func);
}
});
}

void visitConstructorDecl(ConstructorDecl *cd) {
if (SILDeclRef::requiresNewWitnessTableEntry(cd))
if (SILDeclRef::requiresNewWitnessTableEntry(cd)) {
asDerived().addMethod(SILDeclRef(cd, SILDeclRef::Kind::Allocator));
addAutoDiffDerivativeMethodsIfRequired(cd, SILDeclRef::Kind::Allocator);
}
}

void visitAccessorDecl(AccessorDecl *func) {
Expand All @@ -138,8 +143,10 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {

void visitFuncDecl(FuncDecl *func) {
assert(!isa<AccessorDecl>(func));
if (SILDeclRef::requiresNewWitnessTableEntry(func))
if (SILDeclRef::requiresNewWitnessTableEntry(func)) {
asDerived().addMethod(SILDeclRef(func, SILDeclRef::Kind::Func));
addAutoDiffDerivativeMethodsIfRequired(func, SILDeclRef::Kind::Func);
}
}

void visitMissingMemberDecl(MissingMemberDecl *placeholder) {
Expand All @@ -166,6 +173,26 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
void visitPoundDiagnosticDecl(PoundDiagnosticDecl *pdd) {
// We don't care about diagnostics at this stage.
}

private:
void addAutoDiffDerivativeMethodsIfRequired(AbstractFunctionDecl *AFD,
SILDeclRef::Kind kind) {
SILDeclRef declRef(AFD, kind);
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
asDerived().addMethod(declRef.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(),
AFD->getASTContext())));
asDerived().addMethod(declRef.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(),
AFD->getASTContext())));
}
}
};

} // end namespace swift
Expand Down
15 changes: 14 additions & 1 deletion lib/IRGen/GenDiffWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,28 @@ void IRGenModule::emitSILDifferentiabilityWitness(
ConstantInitBuilder builder(*this);
auto diffWitnessContents = builder.beginStruct();

// TODO(TF-1211): Uncomment assertions after upstreaming differentiation
// transform.
// The mandatory differentiation transform canonicalizes differentiability
// witnesses and ensures that JVPs/VJPs are populated.
/*
assert(dw->getJVP() &&
"Differentiability witness definition should have JVP");
assert(dw->getVJP() &&
"Differentiability witness definition should have VJP");

diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
diffWitnessContents.addBitCast(
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
*/
llvm::Constant *jvpValue = llvm::UndefValue::get(Int8PtrTy);
llvm::Constant *vjpValue = llvm::UndefValue::get(Int8PtrTy);
if (auto *jvpFn = dw->getJVP())
jvpValue = getAddrOfSILFunction(dw->getJVP(), NotForDefinition);
if (auto *vjpFn = dw->getJVP())
vjpValue = getAddrOfSILFunction(dw->getVJP(), NotForDefinition);
diffWitnessContents.addBitCast(jvpValue, Int8PtrTy);
diffWitnessContents.addBitCast(vjpValue, Int8PtrTy);

getAddrOfDifferentiabilityWitness(
dw, diffWitnessContents.finishAndCreateFuture());
Expand Down
3 changes: 1 addition & 2 deletions lib/IRGen/GenKeyPath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,8 +992,7 @@ emitKeyPathComponent(IRGenModule &IGM,
auto methodProto = cast<ProtocolDecl>(dc);
auto &protoInfo = IGM.getProtocolInfo(methodProto,
ProtocolInfoKind::Full);
auto index = protoInfo.getFunctionIndex(
cast<AbstractFunctionDecl>(declRef.getDecl()));
auto index = protoInfo.getFunctionIndex(declRef);
idValue = llvm::ConstantInt::get(IGM.SizeTy, -index.getValue());
idResolution = KeyPathComponentHeader::Resolved;
}
Expand Down
18 changes: 8 additions & 10 deletions lib/IRGen/GenProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,20 +792,19 @@ namespace {
}

void addMethod(SILDeclRef func) {
auto decl = cast<AbstractFunctionDecl>(func.getDecl());
// If this assert needs to be changed, be sure to also change
// ProtocolDescriptorBuilder::getRequirementInfo.
assert((isa<ConstructorDecl>(decl)
? (func.kind == SILDeclRef::Kind::Allocator)
: (func.kind == SILDeclRef::Kind::Func))
&& "unexpected kind for protocol witness declaration ref");
Entries.push_back(WitnessTableEntry::forFunction(decl));
assert((isa<ConstructorDecl>(func.getDecl())
? (func.kind == SILDeclRef::Kind::Allocator)
: (func.kind == SILDeclRef::Kind::Func)) &&
"unexpected kind for protocol witness declaration ref");
Entries.push_back(WitnessTableEntry::forFunction(func));
}

void addPlaceholder(MissingMemberDecl *placeholder) {
for (auto i : range(placeholder->getNumberOfVTableEntries())) {
(void)i;
Entries.push_back(WitnessTableEntry());
Entries.push_back(WitnessTableEntry::forPlaceholder());
}
}

Expand Down Expand Up @@ -1318,8 +1317,7 @@ class AccessorConformanceInfo : public ConformanceInfo {
&& "sil witness table does not match protocol");
assert(entry.getMethodWitness().Requirement == requirement
&& "sil witness table does not match protocol");
auto piIndex =
PI.getFunctionIndex(cast<AbstractFunctionDecl>(requirement.getDecl()));
auto piIndex = PI.getFunctionIndex(requirement);
assert((size_t)piIndex.getValue() ==
Table.size() - WitnessTableFirstRequirementOffset &&
"offset doesn't match ProtocolInfo layout");
Expand Down Expand Up @@ -3277,7 +3275,7 @@ FunctionPointer irgen::emitWitnessMethodValue(IRGenFunction &IGF,

// Find the witness we're interested in.
auto &fnProtoInfo = IGF.IGM.getProtocolInfo(proto, ProtocolInfoKind::Full);
auto index = fnProtoInfo.getFunctionIndex(fn);
auto index = fnProtoInfo.getFunctionIndex(member);
llvm::Value *slot;
llvm::Value *witnessFnPtr =
emitInvariantLoadOfOpaqueWitness(IGF, wtable,
Expand Down
Loading