Skip to content

Commit 9aed7ce

Browse files
authored
Merge pull request #35667 from rxwei/more-ad-thunks
2 parents 450bc77 + 2f883a3 commit 9aed7ce

File tree

9 files changed

+130
-79
lines changed

9 files changed

+130
-79
lines changed

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ NODE(CanonicalPrespecializedGenericTypeCachingOnceToken)
309309

310310
// Added in Swift 5.5
311311
NODE(AsyncFunctionPointer)
312-
NODE(AutoDiffFunction)
312+
CONTEXT_NODE(AutoDiffFunction)
313313
NODE(AutoDiffFunctionKind)
314314
NODE(IndexSubset)
315315

include/swift/IRGen/Linking.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ class LinkEntity {
126126
/// or a class.
127127
DispatchThunk,
128128

129+
/// A derivative method dispatch thunk. The pointer is a
130+
/// AbstractFunctionDecl* inside a protocol or a class, and the secondary
131+
/// pointer is an AutoDiffDerivativeFunctionIdentifier*.
132+
DispatchThunkDerivative,
133+
129134
/// A method dispatch thunk for an initializing constructor. The pointer
130135
/// is a ConstructorDecl* inside a class.
131136
DispatchThunkInitializer,
@@ -152,6 +157,11 @@ class LinkEntity {
152157
/// or a class.
153158
MethodDescriptor,
154159

160+
/// A derivative method descriptor. The pointer is a AbstractFunctionDecl*
161+
/// inside a protocol or a class, and the secondary pointer is an
162+
/// AutoDiffDerivativeFunctionIdentifier*.
163+
MethodDescriptorDerivative,
164+
155165
/// A method descriptor for an initializing constructor. The pointer
156166
/// is a ConstructorDecl* inside a class.
157167
MethodDescriptorInitializer,
@@ -618,6 +628,16 @@ class LinkEntity {
618628
static LinkEntity forDispatchThunk(SILDeclRef declRef) {
619629
assert(isValidResilientMethodRef(declRef));
620630

631+
if (declRef.isAutoDiffDerivativeFunction()) {
632+
LinkEntity entity;
633+
// The derivative function for any decl is always a method (not an
634+
// initializer).
635+
entity.setForDecl(Kind::DispatchThunkDerivative, declRef.getDecl());
636+
entity.SecondaryPointer =
637+
declRef.getAutoDiffDerivativeFunctionIdentifier();
638+
return entity;
639+
}
640+
621641
LinkEntity::Kind kind;
622642
switch (declRef.kind) {
623643
case SILDeclRef::Kind::Func:
@@ -641,6 +661,16 @@ class LinkEntity {
641661
static LinkEntity forMethodDescriptor(SILDeclRef declRef) {
642662
assert(isValidResilientMethodRef(declRef));
643663

664+
if (declRef.isAutoDiffDerivativeFunction()) {
665+
LinkEntity entity;
666+
// The derivative function for any decl is always a method (not an
667+
// initializer).
668+
entity.setForDecl(Kind::MethodDescriptorDerivative, declRef.getDecl());
669+
entity.SecondaryPointer =
670+
declRef.getAutoDiffDerivativeFunctionIdentifier();
671+
return entity;
672+
}
673+
644674
LinkEntity::Kind kind;
645675
switch (declRef.kind) {
646676
case SILDeclRef::Kind::Func:
@@ -1263,6 +1293,15 @@ class LinkEntity {
12631293
assert(getKind() == Kind::AssociatedTypeWitnessTableAccessFunction);
12641294
return reinterpret_cast<ProtocolDecl*>(Pointer);
12651295
}
1296+
1297+
AutoDiffDerivativeFunctionIdentifier *
1298+
getAutoDiffDerivativeFunctionIdentifier() const {
1299+
assert(getKind() == Kind::DispatchThunkDerivative ||
1300+
getKind() == Kind::MethodDescriptorDerivative);
1301+
return reinterpret_cast<AutoDiffDerivativeFunctionIdentifier*>(
1302+
SecondaryPointer);
1303+
}
1304+
12661305
bool isDynamicallyReplaceable() const {
12671306
assert(getKind() == Kind::SILFunction);
12681307
return LINKENTITY_GET_FIELD(Data, IsDynamicallyReplaceableImpl);

include/swift/SIL/SILVTableVisitor.h

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,7 @@ template <class T> class SILVTableVisitor {
4545
SILDeclRef constant(fd, SILDeclRef::Kind::Func);
4646
maybeAddEntry(constant);
4747

48-
for (auto *diffAttr : fd->getAttrs().getAttributes<DifferentiableAttr>()) {
49-
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
50-
AutoDiffDerivativeFunctionIdentifier::get(
51-
AutoDiffDerivativeFunctionKind::JVP,
52-
diffAttr->getParameterIndices(),
53-
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
54-
maybeAddEntry(jvpConstant);
55-
56-
auto vjpConstant = constant.asAutoDiffDerivativeFunction(
57-
AutoDiffDerivativeFunctionIdentifier::get(
58-
AutoDiffDerivativeFunctionKind::VJP,
59-
diffAttr->getParameterIndices(),
60-
diffAttr->getDerivativeGenericSignature(), fd->getASTContext()));
61-
maybeAddEntry(vjpConstant);
62-
}
48+
maybeAddAutoDiffDerivativeMethods(constant);
6349
}
6450

6551
void maybeAddConstructor(ConstructorDecl *cd) {
@@ -72,21 +58,7 @@ template <class T> class SILVTableVisitor {
7258
SILDeclRef constant(cd, SILDeclRef::Kind::Allocator);
7359
maybeAddEntry(constant);
7460

75-
for (auto *diffAttr : cd->getAttrs().getAttributes<DifferentiableAttr>()) {
76-
auto jvpConstant = constant.asAutoDiffDerivativeFunction(
77-
AutoDiffDerivativeFunctionIdentifier::get(
78-
AutoDiffDerivativeFunctionKind::JVP,
79-
diffAttr->getParameterIndices(),
80-
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
81-
maybeAddEntry(jvpConstant);
82-
83-
auto vjpConstant = constant.asAutoDiffDerivativeFunction(
84-
AutoDiffDerivativeFunctionIdentifier::get(
85-
AutoDiffDerivativeFunctionKind::VJP,
86-
diffAttr->getParameterIndices(),
87-
diffAttr->getDerivativeGenericSignature(), cd->getASTContext()));
88-
maybeAddEntry(vjpConstant);
89-
}
61+
maybeAddAutoDiffDerivativeMethods(constant);
9062
}
9163

9264
void maybeAddAccessors(AbstractStorageDecl *asd) {
@@ -142,6 +114,24 @@ template <class T> class SILVTableVisitor {
142114
asDerived().addPlaceholder(placeholder);
143115
}
144116

117+
void maybeAddAutoDiffDerivativeMethods(SILDeclRef constant) {
118+
auto *D = constant.getDecl();
119+
for (auto *diffAttr : D->getAttrs().getAttributes<DifferentiableAttr>()) {
120+
maybeAddEntry(constant.asAutoDiffDerivativeFunction(
121+
AutoDiffDerivativeFunctionIdentifier::get(
122+
AutoDiffDerivativeFunctionKind::JVP,
123+
diffAttr->getParameterIndices(),
124+
diffAttr->getDerivativeGenericSignature(),
125+
D->getASTContext())));
126+
maybeAddEntry(constant.asAutoDiffDerivativeFunction(
127+
AutoDiffDerivativeFunctionIdentifier::get(
128+
AutoDiffDerivativeFunctionKind::VJP,
129+
diffAttr->getParameterIndices(),
130+
diffAttr->getDerivativeGenericSignature(),
131+
D->getASTContext())));
132+
}
133+
}
134+
145135
protected:
146136
void addVTableEntries(ClassDecl *theClass) {
147137
// Imported classes do not have a vtable.

lib/Demangling/OldRemangler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ void Remangler::mangleReabstractionThunk(Node *node) {
744744
Buffer << "<reabstraction-thunk>";
745745
}
746746

747-
void Remangler::mangleAutoDiffFunction(Node *node) {
747+
void Remangler::mangleAutoDiffFunction(Node *node, EntityContext &ctx) {
748748
Buffer << "<autodiff-function>";
749749
}
750750

lib/IRGen/IRGenMangler.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "IRGenModule.h"
1717
#include "swift/AST/ASTMangler.h"
18+
#include "swift/AST/AutoDiff.h"
1819
#include "swift/AST/ProtocolAssociations.h"
1920
#include "swift/IRGen/ValueWitness.h"
2021
#include "llvm/Support/SaveAndRestore.h"
@@ -51,6 +52,21 @@ class IRGenMangler : public Mangle::ASTMangler {
5152
return finalize();
5253
}
5354

55+
std::string mangleDerivativeDispatchThunk(
56+
const AbstractFunctionDecl *func,
57+
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
58+
beginManglingWithAutoDiffOriginalFunction(func);
59+
auto kindCode =
60+
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
61+
AutoDiffConfig config(
62+
derivativeId->getParameterIndices(),
63+
IndexSubset::get(func->getASTContext(), 1, {0}),
64+
derivativeId->getDerivativeGenericSignature());
65+
appendAutoDiffFunctionParts(kindCode, config);
66+
appendOperator("Tj");
67+
return finalize();
68+
}
69+
5470
std::string mangleConstructorDispatchThunk(const ConstructorDecl *ctor,
5571
bool isAllocating) {
5672
beginMangling();
@@ -66,6 +82,21 @@ class IRGenMangler : public Mangle::ASTMangler {
6682
return finalize();
6783
}
6884

85+
std::string mangleDerivativeMethodDescriptor(
86+
const AbstractFunctionDecl *func,
87+
AutoDiffDerivativeFunctionIdentifier *derivativeId) {
88+
beginManglingWithAutoDiffOriginalFunction(func);
89+
auto kindCode =
90+
(char)Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
91+
AutoDiffConfig config(
92+
derivativeId->getParameterIndices(),
93+
IndexSubset::get(func->getASTContext(), 1, {0}),
94+
derivativeId->getDerivativeGenericSignature());
95+
appendAutoDiffFunctionParts(kindCode, config);
96+
appendOperator("Tq");
97+
return finalize();
98+
}
99+
69100
std::string mangleConstructorMethodDescriptor(const ConstructorDecl *ctor,
70101
bool isAllocating) {
71102
beginMangling();

lib/IRGen/Linking.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ std::string LinkEntity::mangleAsString() const {
104104
return mangler.mangleDispatchThunk(func);
105105
}
106106

107+
case Kind::DispatchThunkDerivative: {
108+
auto *func = cast<AbstractFunctionDecl>(getDecl());
109+
auto *derivativeId = getAutoDiffDerivativeFunctionIdentifier();
110+
return mangler.mangleDerivativeDispatchThunk(func, derivativeId);
111+
}
112+
107113
case Kind::DispatchThunkInitializer: {
108114
auto *ctor = cast<ConstructorDecl>(getDecl());
109115
return mangler.mangleConstructorDispatchThunk(ctor,
@@ -121,6 +127,12 @@ std::string LinkEntity::mangleAsString() const {
121127
return mangler.mangleMethodDescriptor(func);
122128
}
123129

130+
case Kind::MethodDescriptorDerivative: {
131+
auto *func = cast<AbstractFunctionDecl>(getDecl());
132+
auto *derivativeId = getAutoDiffDerivativeFunctionIdentifier();
133+
return mangler.mangleDerivativeMethodDescriptor(func, derivativeId);
134+
}
135+
124136
case Kind::MethodDescriptorInitializer: {
125137
auto *ctor = cast<ConstructorDecl>(getDecl());
126138
return mangler.mangleConstructorMethodDescriptor(ctor,
@@ -460,9 +472,11 @@ SILLinkage LinkEntity::getLinkage(ForDefinition_t forDefinition) const {
460472

461473
switch (getKind()) {
462474
case Kind::DispatchThunk:
475+
case Kind::DispatchThunkDerivative:
463476
case Kind::DispatchThunkInitializer:
464477
case Kind::DispatchThunkAllocator:
465478
case Kind::MethodDescriptor:
479+
case Kind::MethodDescriptorDerivative:
466480
case Kind::MethodDescriptorInitializer:
467481
case Kind::MethodDescriptorAllocator: {
468482
auto *decl = getDecl();
@@ -742,12 +756,14 @@ bool LinkEntity::isContextDescriptor() const {
742756
case Kind::AsyncFunctionPointerAST:
743757
case Kind::PropertyDescriptor:
744758
case Kind::DispatchThunk:
759+
case Kind::DispatchThunkDerivative:
745760
case Kind::DispatchThunkInitializer:
746761
case Kind::DispatchThunkAllocator:
747762
case Kind::DispatchThunkAsyncFunctionPointer:
748763
case Kind::DispatchThunkInitializerAsyncFunctionPointer:
749764
case Kind::DispatchThunkAllocatorAsyncFunctionPointer:
750765
case Kind::MethodDescriptor:
766+
case Kind::MethodDescriptorDerivative:
751767
case Kind::MethodDescriptorInitializer:
752768
case Kind::MethodDescriptorAllocator:
753769
case Kind::MethodLookupFunction:
@@ -892,6 +908,7 @@ llvm::Type *LinkEntity::getDefaultDeclarationType(IRGenModule &IGM) const {
892908
case Kind::MethodDescriptor:
893909
case Kind::MethodDescriptorInitializer:
894910
case Kind::MethodDescriptorAllocator:
911+
case Kind::MethodDescriptorDerivative:
895912
return IGM.MethodDescriptorStructTy;
896913
case Kind::DynamicallyReplaceableFunctionKey:
897914
case Kind::OpaqueTypeDescriptorAccessorKey:
@@ -1020,9 +1037,11 @@ bool LinkEntity::isWeakImported(ModuleDecl *module) const {
10201037

10211038
case Kind::AsyncFunctionPointerAST:
10221039
case Kind::DispatchThunk:
1040+
case Kind::DispatchThunkDerivative:
10231041
case Kind::DispatchThunkInitializer:
10241042
case Kind::DispatchThunkAllocator:
10251043
case Kind::MethodDescriptor:
1044+
case Kind::MethodDescriptorDerivative:
10261045
case Kind::MethodDescriptorInitializer:
10271046
case Kind::MethodDescriptorAllocator:
10281047
case Kind::MethodLookupFunction:
@@ -1104,9 +1123,11 @@ DeclContext *LinkEntity::getDeclContextForEmission() const {
11041123
switch (getKind()) {
11051124
case Kind::AsyncFunctionPointerAST:
11061125
case Kind::DispatchThunk:
1126+
case Kind::DispatchThunkDerivative:
11071127
case Kind::DispatchThunkInitializer:
11081128
case Kind::DispatchThunkAllocator:
11091129
case Kind::MethodDescriptor:
1130+
case Kind::MethodDescriptorDerivative:
11101131
case Kind::MethodDescriptorInitializer:
11111132
case Kind::MethodDescriptorAllocator:
11121133
case Kind::MethodLookupFunction:

test/AutoDiff/TBD/derivative_symbols.swift

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,22 @@ extension Array where Element == Struct {
8383
}
8484
}
8585

86+
// SR-13866: Dispatch thunks and method descriptor mangling.
87+
public protocol P: Differentiable {
88+
@differentiable(wrt: self)
89+
@differentiable(wrt: (self, x))
90+
func method(_ x: Float) -> Float
91+
92+
@differentiable(wrt: self)
93+
var property: Float { get set }
94+
95+
@differentiable(wrt: self)
96+
@differentiable(wrt: (self, x))
97+
subscript(_ x: Float) -> Float { get set }
98+
}
8699

87-
/* FIXME(SR-13866): Enable the following tests once we've fixed TBDGen for dispatch
88-
* thunks and method descriptors.
100+
/* FIXME(rdar://73791807): Enable the following tests once we've fixed TBDGen
101+
for derivative vtable entry thunks.
89102
public final class Class: Differentiable {
90103
var stored: Float
91104

test/AutoDiff/compiler_crashers/sr13866-library-evolution-mode-tbdgen-crasher-protocool-requirement.swift

Lines changed: 0 additions & 45 deletions
This file was deleted.

test/Demangle/Inputs/manglings.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,4 +375,6 @@ $s4main6testityyYFTu ---> async function pointer to main.testit() async -> ()
375375
$s13test_mangling3fooyS2f_S2ftFTJfUSSpSr ---> forward-mode derivative of test_mangling.foo(Swift.Float, Swift.Float, Swift.Float) -> Swift.Float with respect to parameters {1, 2} and results {0}
376376
$s13test_mangling4foo21xq_x_t16_Differentiation14DifferentiableR_AA1P13TangentVectorRp_r0_lFAdERzAdER_AafGRpzAafHRQr0_lTJrSpSr ---> reverse-mode derivative of test_mangling.foo2<A, B where B: _Differentiation.Differentiable, B.TangentVector: test_mangling.P>(x: A) -> B with respect to parameters {0} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable, A.TangentVector: test_mangling.P, B.TangentVector: test_mangling.P>
377377
$s13test_mangling3fooyS2f_xq_t16_Differentiation14DifferentiableR_r0_lFAcDRzAcDR_r0_lTJpUSSpSr ---> pullback of test_mangling.foo<A, B where B: _Differentiation.Differentiable>(Swift.Float, A, B) -> Swift.Float with respect to parameters {1, 2} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable>
378+
$s13test_mangling3fooyS2f_xq_t16_Differentiation14DifferentiableR_r0_lFAcDRzAcDR_r0_lTJpUSSpSrTj ---> dispatch thunk of pullback of test_mangling.foo<A, B where B: _Differentiation.Differentiable>(Swift.Float, A, B) -> Swift.Float with respect to parameters {1, 2} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable>
379+
$s13test_mangling3fooyS2f_xq_t16_Differentiation14DifferentiableR_r0_lFAcDRzAcDR_r0_lTJpUSSpSrTq ---> method descriptor for pullback of test_mangling.foo<A, B where B: _Differentiation.Differentiable>(Swift.Float, A, B) -> Swift.Float with respect to parameters {1, 2} and results {0} with <A, B where A: _Differentiation.Differentiable, B: _Differentiation.Differentiable>
378380
$s5async1hyyS2iJXEF ---> async.h(@concurrent (Swift.Int) -> Swift.Int) -> ()

0 commit comments

Comments
 (0)