Skip to content

[AutoDiff] Fix TBDGen issues for derivative dispatch thunks and method descriptors. #35667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ NODE(CanonicalPrespecializedGenericTypeCachingOnceToken)

// Added in Swift 5.5
NODE(AsyncFunctionPointer)
NODE(AutoDiffFunction)
CONTEXT_NODE(AutoDiffFunction)
NODE(AutoDiffFunctionKind)
NODE(IndexSubset)

Expand Down
39 changes: 39 additions & 0 deletions include/swift/IRGen/Linking.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class LinkEntity {
/// or a class.
DispatchThunk,

/// A derivative method dispatch thunk. The pointer is a
/// AbstractFunctionDecl* inside a protocol or a class, and the secondary
/// pointer is an AutoDiffDerivativeFunctionIdentifier*.
DispatchThunkDerivative,

/// A method dispatch thunk for an initializing constructor. The pointer
/// is a ConstructorDecl* inside a class.
DispatchThunkInitializer,
Expand All @@ -152,6 +157,11 @@ class LinkEntity {
/// or a class.
MethodDescriptor,

/// A derivative method descriptor. The pointer is a AbstractFunctionDecl*
/// inside a protocol or a class, and the secondary pointer is an
/// AutoDiffDerivativeFunctionIdentifier*.
MethodDescriptorDerivative,

/// A method descriptor for an initializing constructor. The pointer
/// is a ConstructorDecl* inside a class.
MethodDescriptorInitializer,
Expand Down Expand Up @@ -618,6 +628,16 @@ class LinkEntity {
static LinkEntity forDispatchThunk(SILDeclRef declRef) {
assert(isValidResilientMethodRef(declRef));

if (declRef.isAutoDiffDerivativeFunction()) {
LinkEntity entity;
// The derivative function for any decl is always a method (not an
// initializer).
entity.setForDecl(Kind::DispatchThunkDerivative, declRef.getDecl());
entity.SecondaryPointer =
declRef.getAutoDiffDerivativeFunctionIdentifier();
return entity;
}

LinkEntity::Kind kind;
switch (declRef.kind) {
case SILDeclRef::Kind::Func:
Expand All @@ -641,6 +661,16 @@ class LinkEntity {
static LinkEntity forMethodDescriptor(SILDeclRef declRef) {
assert(isValidResilientMethodRef(declRef));

if (declRef.isAutoDiffDerivativeFunction()) {
LinkEntity entity;
// The derivative function for any decl is always a method (not an
// initializer).
entity.setForDecl(Kind::MethodDescriptorDerivative, declRef.getDecl());
entity.SecondaryPointer =
declRef.getAutoDiffDerivativeFunctionIdentifier();
return entity;
}

LinkEntity::Kind kind;
switch (declRef.kind) {
case SILDeclRef::Kind::Func:
Expand Down Expand Up @@ -1263,6 +1293,15 @@ class LinkEntity {
assert(getKind() == Kind::AssociatedTypeWitnessTableAccessFunction);
return reinterpret_cast<ProtocolDecl*>(Pointer);
}

AutoDiffDerivativeFunctionIdentifier *
getAutoDiffDerivativeFunctionIdentifier() const {
assert(getKind() == Kind::DispatchThunkDerivative ||
getKind() == Kind::MethodDescriptorDerivative);
return reinterpret_cast<AutoDiffDerivativeFunctionIdentifier*>(
SecondaryPointer);
}

bool isDynamicallyReplaceable() const {
assert(getKind() == Kind::SILFunction);
return LINKENTITY_GET_FIELD(Data, IsDynamicallyReplaceableImpl);
Expand Down
50 changes: 20 additions & 30 deletions include/swift/SIL/SILVTableVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,7 @@ template <class T> class SILVTableVisitor {
SILDeclRef constant(fd, SILDeclRef::Kind::Func);
maybeAddEntry(constant);

for (auto *diffAttr : fd->getAttrs().getAttributes<DifferentiableAttr>()) {
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
maybeAddEntry(jvpConstant);

auto vjpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
maybeAddEntry(vjpConstant);
}
maybeAddAutoDiffDerivativeMethods(constant);
}

void maybeAddConstructor(ConstructorDecl *cd) {
Expand All @@ -72,21 +58,7 @@ template <class T> class SILVTableVisitor {
SILDeclRef constant(cd, SILDeclRef::Kind::Allocator);
maybeAddEntry(constant);

for (auto *diffAttr : cd->getAttrs().getAttributes<DifferentiableAttr>()) {
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
maybeAddEntry(jvpConstant);

auto vjpConstant = constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
maybeAddEntry(vjpConstant);
}
maybeAddAutoDiffDerivativeMethods(constant);
}

void maybeAddAccessors(AbstractStorageDecl *asd) {
Expand Down Expand Up @@ -142,6 +114,24 @@ template <class T> class SILVTableVisitor {
asDerived().addPlaceholder(placeholder);
}

void maybeAddAutoDiffDerivativeMethods(SILDeclRef constant) {
auto *D = constant.getDecl();
for (auto *diffAttr : D->getAttrs().getAttributes<DifferentiableAttr>()) {
maybeAddEntry(constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::JVP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(),
D->getASTContext())));
maybeAddEntry(constant.asAutoDiffDerivativeFunction(
AutoDiffDerivativeFunctionIdentifier::get(
AutoDiffDerivativeFunctionKind::VJP,
diffAttr->getParameterIndices(),
diffAttr->getDerivativeGenericSignature(),
D->getASTContext())));
}
}

protected:
void addVTableEntries(ClassDecl *theClass) {
// Imported classes do not have a vtable.
Expand Down
2 changes: 1 addition & 1 deletion lib/Demangling/OldRemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ void Remangler::mangleReabstractionThunk(Node *node) {
Buffer << "<reabstraction-thunk>";
}

void Remangler::mangleAutoDiffFunction(Node *node) {
void Remangler::mangleAutoDiffFunction(Node *node, EntityContext &ctx) {
Buffer << "<autodiff-function>";
}

Expand Down
31 changes: 31 additions & 0 deletions lib/IRGen/IRGenMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "IRGenModule.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ProtocolAssociations.h"
#include "swift/IRGen/ValueWitness.h"
#include "llvm/Support/SaveAndRestore.h"
Expand Down Expand Up @@ -51,6 +52,21 @@ class IRGenMangler : public Mangle::ASTMangler {
return finalize();
}

std::string mangleDerivativeDispatchThunk(
const AbstractFunctionDecl *func,
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
beginManglingWithAutoDiffOriginalFunction(func);
auto kindCode =
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
AutoDiffConfig config(
derivativeId->getParameterIndices(),
IndexSubset::get(func->getASTContext(), 1, {0}),
derivativeId->getDerivativeGenericSignature());
appendAutoDiffFunctionParts(kindCode, config);
appendOperator("Tj");
return finalize();
}

std::string mangleConstructorDispatchThunk(const ConstructorDecl *ctor,
bool isAllocating) {
beginMangling();
Expand All @@ -66,6 +82,21 @@ class IRGenMangler : public Mangle::ASTMangler {
return finalize();
}

std::string mangleDerivativeMethodDescriptor(
const AbstractFunctionDecl *func,
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
beginManglingWithAutoDiffOriginalFunction(func);
auto kindCode =
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
AutoDiffConfig config(
derivativeId->getParameterIndices(),
IndexSubset::get(func->getASTContext(), 1, {0}),
derivativeId->getDerivativeGenericSignature());
appendAutoDiffFunctionParts(kindCode, config);
appendOperator("Tq");
return finalize();
}

std::string mangleConstructorMethodDescriptor(const ConstructorDecl *ctor,
bool isAllocating) {
beginMangling();
Expand Down
21 changes: 21 additions & 0 deletions lib/IRGen/Linking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ std::string LinkEntity::mangleAsString() const {
return mangler.mangleDispatchThunk(func);
}

case Kind::DispatchThunkDerivative: {
auto *func = cast<AbstractFunctionDecl>(getDecl());
auto *derivativeId = getAutoDiffDerivativeFunctionIdentifier();
return mangler.mangleDerivativeDispatchThunk(func, derivativeId);
}

case Kind::DispatchThunkInitializer: {
auto *ctor = cast<ConstructorDecl>(getDecl());
return mangler.mangleConstructorDispatchThunk(ctor,
Expand All @@ -121,6 +127,12 @@ std::string LinkEntity::mangleAsString() const {
return mangler.mangleMethodDescriptor(func);
}

case Kind::MethodDescriptorDerivative: {
auto *func = cast<AbstractFunctionDecl>(getDecl());
auto *derivativeId = getAutoDiffDerivativeFunctionIdentifier();
return mangler.mangleDerivativeMethodDescriptor(func, derivativeId);
}

case Kind::MethodDescriptorInitializer: {
auto *ctor = cast<ConstructorDecl>(getDecl());
return mangler.mangleConstructorMethodDescriptor(ctor,
Expand Down Expand Up @@ -460,9 +472,11 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {

switch (getKind()) {
case Kind::DispatchThunk:
case Kind::DispatchThunkDerivative:
case Kind::DispatchThunkInitializer:
case Kind::DispatchThunkAllocator:
case Kind::MethodDescriptor:
case Kind::MethodDescriptorDerivative:
case Kind::MethodDescriptorInitializer:
case Kind::MethodDescriptorAllocator: {
auto *decl = getDecl();
Expand Down Expand Up @@ -742,12 +756,14 @@ bool LinkEntity::isContextDescriptor() const {
case Kind::AsyncFunctionPointerAST:
case Kind::PropertyDescriptor:
case Kind::DispatchThunk:
case Kind::DispatchThunkDerivative:
case Kind::DispatchThunkInitializer:
case Kind::DispatchThunkAllocator:
case Kind::DispatchThunkAsyncFunctionPointer:
case Kind::DispatchThunkInitializerAsyncFunctionPointer:
case Kind::DispatchThunkAllocatorAsyncFunctionPointer:
case Kind::MethodDescriptor:
case Kind::MethodDescriptorDerivative:
case Kind::MethodDescriptorInitializer:
case Kind::MethodDescriptorAllocator:
case Kind::MethodLookupFunction:
Expand Down Expand Up @@ -892,6 +908,7 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
case Kind::MethodDescriptor:
case Kind::MethodDescriptorInitializer:
case Kind::MethodDescriptorAllocator:
case Kind::MethodDescriptorDerivative:
return IGM.MethodDescriptorStructTy;
case Kind::DynamicallyReplaceableFunctionKey:
case Kind::OpaqueTypeDescriptorAccessorKey:
Expand Down Expand Up @@ -1020,9 +1037,11 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {

case Kind::AsyncFunctionPointerAST:
case Kind::DispatchThunk:
case Kind::DispatchThunkDerivative:
case Kind::DispatchThunkInitializer:
case Kind::DispatchThunkAllocator:
case Kind::MethodDescriptor:
case Kind::MethodDescriptorDerivative:
case Kind::MethodDescriptorInitializer:
case Kind::MethodDescriptorAllocator:
case Kind::MethodLookupFunction:
Expand Down Expand Up @@ -1104,9 +1123,11 @@ DeclContext *LinkEntity::getDeclContextForEmission() const {
switch (getKind()) {
case Kind::AsyncFunctionPointerAST:
case Kind::DispatchThunk:
case Kind::DispatchThunkDerivative:
case Kind::DispatchThunkInitializer:
case Kind::DispatchThunkAllocator:
case Kind::MethodDescriptor:
case Kind::MethodDescriptorDerivative:
case Kind::MethodDescriptorInitializer:
case Kind::MethodDescriptorAllocator:
case Kind::MethodLookupFunction:
Expand Down
17 changes: 15 additions & 2 deletions test/AutoDiff/TBD/derivative_symbols.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,22 @@ extension Array where Element == Struct {
}
}

// SR-13866: Dispatch thunks and method descriptor mangling.
public protocol P: Differentiable {
@differentiable(wrt: self)
@differentiable(wrt: (self, x))
func method(_ x: Float) -> Float

@differentiable(wrt: self)
var property: Float { get set }

@differentiable(wrt: self)
@differentiable(wrt: (self, x))
subscript(_ x: Float) -> Float { get set }
}

/* FIXME(SR-13866): Enable the following tests once we've fixed TBDGen for dispatch
* thunks and method descriptors.
/* FIXME(rdar://73791807): Enable the following tests once we've fixed TBDGen
for derivative vtable entry thunks.
public final class Class: Differentiable {
var stored: Float

Expand Down

This file was deleted.

2 changes: 2 additions & 0 deletions test/Demangle/Inputs/manglings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,6 @@ $s4main6testityyYFTu ---> async function pointer to main.testit() async -> ()
$s13test_mangling3fooyS2f_S2ftFTJfUSSpSr ---> forward-mode derivative of test_mangling.foo(Swift.Float, Swift.Float, Swift.Float) -> Swift.Float with respect to parameters {1, 2} and results {0}
$s13test_mangling4foo21xq_x_t16_Differentiation14DifferentiableR_AA1P13TangentVectorRp_r0_lFAdERzAdER_AafGRpzAafHRQr0_lTJrSpSr ---> reverse-mode derivative of test_mangling.foo2<A, B where B: _Differentiation.Differentiable, B.TangentVector: test_mangling.P>(x: A) -> B with respect to parameters {0} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable, A.TangentVector: test_mangling.P, B.TangentVector: test_mangling.P>
$s13test_mangling3fooyS2f_xq_t16_Differentiation14DifferentiableR_r0_lFAcDRzAcDR_r0_lTJpUSSpSr ---> pullback of test_mangling.foo<A, B where B: _Differentiation.Differentiable>(Swift.Float, A, B) -> Swift.Float with respect to parameters {1, 2} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable>
$s13test_mangling3fooyS2f_xq_t16_Differentiation14DifferentiableR_r0_lFAcDRzAcDR_r0_lTJpUSSpSrTj ---> dispatch thunk of pullback of test_mangling.foo<A, B where B: _Differentiation.Differentiable>(Swift.Float, A, B) -> Swift.Float with respect to parameters {1, 2} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable>
$s13test_mangling3fooyS2f_xq_t16_Differentiation14DifferentiableR_r0_lFAcDRzAcDR_r0_lTJpUSSpSrTq ---> method descriptor for pullback of test_mangling.foo<A, B where B: _Differentiation.Differentiable>(Swift.Float, A, B) -> Swift.Float with respect to parameters {1, 2} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable>
$s5async1hyyS2iJXEF ---> async.h(@concurrent (Swift.Int) -> Swift.Int) -> ()