Skip to content

Commit d33c79e

Browse files
authored
Merge pull request #35847 from rxwei/72666310-mangle-ad-thunks
2 parents 7017362 + f9ddecf commit d33c79e

File tree

18 files changed

+380
-115
lines changed

18 files changed

+380
-115
lines changed

docs/ABI/Mangling.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ types where the metadata itself has unknown layout.)
220220
global ::= global 'Tm' // merged function
221221
global ::= entity // some identifiable thing
222222
global ::= from-type to-type generic-signature? 'TR' // reabstraction thunk
223-
global ::= from-type to-type generic-signature? 'TR' // reabstraction thunk
224223
global ::= impl-function-type type 'Tz' // objc-to-swift-async completion handler block implementation
225224
global ::= impl-function-type type 'TZ' // objc-to-swift-async completion handler block implementation (predefined by runtime)
226225
global ::= from-type to-type self-type generic-signature? 'Ty' // reabstraction thunk with dynamic 'Self' capture
@@ -230,6 +229,9 @@ types where the metadata itself has unknown layout.)
230229
global ::= type generic-signature 'TH' // key path equality
231230
global ::= type generic-signature 'Th' // key path hasher
232231
global ::= global generic-signature? 'TJ' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // autodiff function
232+
global ::= from-type to-type 'TJO' AUTODIFF-FUNCTION-KIND // autodiff self-reordering reabstraction thunk
233+
global ::= from-type 'TJS' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' INDEX-SUBSET 'P' // autodiff linear map subset parameters thunk
234+
global ::= global to-type 'TJS' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' INDEX-SUBSET 'P' // autodiff derivative function subset parameters thunk
233235

234236
global ::= protocol 'TL' // protocol requirements base descriptor
235237
global ::= assoc-type-name 'Tl' // associated type descriptor

include/swift/AST/ASTMangler.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,15 @@ class ASTMangler : public Mangler {
192192
AutoDiffLinearMapKind kind,
193193
AutoDiffConfig config);
194194

195+
/// Mangle the linear map self parameter reordering thunk the given:
196+
/// - Mangled original function declaration.
197+
/// - Linear map kind.
198+
/// - Derivative function configuration: parameter/result indices and
199+
/// derivative generic signature.
200+
std::string mangleAutoDiffSelfReorderingReabstractionThunk(
201+
CanType fromType, CanType toType, GenericSignature signature,
202+
AutoDiffLinearMapKind linearMapKind);
203+
195204
/// Mangle the AutoDiff generated declaration for the given:
196205
/// - Generated declaration kind: linear map struct or branching trace enum.
197206
/// - Mangled original function name.

include/swift/Demangling/DemangleNodes.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ NODE(CanonicalPrespecializedGenericTypeCachingOnceToken)
310310
NODE(AsyncFunctionPointer)
311311
CONTEXT_NODE(AutoDiffFunction)
312312
NODE(AutoDiffFunctionKind)
313+
NODE(AutoDiffSelfReorderingReabstractionThunk)
314+
NODE(AutoDiffSubsetParametersThunk)
313315
NODE(IndexSubset)
314316

315317
#undef CONTEXT_NODE

include/swift/Demangling/Demangler.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,10 @@ class Demangler : public NodeFactory {
569569

570570
NodePointer demangleTypeMangling();
571571
NodePointer demangleSymbolicReference(unsigned char rawKind);
572+
NodePointer demangleAutoDiffFunction();
572573
NodePointer demangleAutoDiffFunctionKind();
574+
NodePointer demangleAutoDiffSubsetParametersThunk();
575+
NodePointer demangleAutoDiffSelfReorderingReabstractionThunk();
573576
NodePointer demangleIndexSubset();
574577

575578
bool demangleBoundGenerics(Vector<NodePointer> &TypeListList,

include/swift/SILOptimizer/Utils/DifferentiationMangler.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,29 @@ class DifferentiationMangler : public ASTMangler {
2727
public:
2828
DifferentiationMangler() {}
2929
/// Returns the mangled name for a differentiation function of the given kind.
30-
std::string mangle(SILFunction *originalFunction,
31-
Demangle::AutoDiffFunctionKind kind,
32-
AutoDiffConfig config);
30+
std::string mangleAutoDiffFunction(StringRef originalName,
31+
Demangle::AutoDiffFunctionKind kind,
32+
AutoDiffConfig config);
3333
/// Returns the mangled name for a derivative function of the given kind.
34-
std::string mangleDerivativeFunction(SILFunction *originalFunction,
34+
std::string mangleDerivativeFunction(StringRef originalName,
3535
AutoDiffDerivativeFunctionKind kind,
3636
AutoDiffConfig config);
3737
/// Returns the mangled name for a linear map of the given kind.
38-
std::string mangleLinearMap(SILFunction *originalFunction,
38+
std::string mangleLinearMap(StringRef originalName,
3939
AutoDiffLinearMapKind kind,
4040
AutoDiffConfig config);
41+
/// Returns the mangled name for a derivative function subset parameters
42+
/// thunk.
43+
std::string mangleDerivativeFunctionSubsetParametersThunk(
44+
StringRef originalName, CanType toType,
45+
AutoDiffDerivativeFunctionKind linearMapKind,
46+
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
47+
IndexSubset *toParamIndices);
48+
/// Returns the mangled name for a linear map subset parameters thunk.
49+
std::string mangleLinearMapSubsetParametersThunk(
50+
CanType fromType, AutoDiffLinearMapKind linearMapKind,
51+
IndexSubset *fromParamIndices, IndexSubset *fromResultIndices,
52+
IndexSubset *toParamIndices);
4153
};
4254

4355
} // end namespace Mangle

lib/AST/ASTMangler.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,19 @@ void ASTMangler::appendAutoDiffFunctionParts(char functionKindCode,
448448
appendOperator("r");
449449
}
450450

451+
std::string ASTMangler::mangleAutoDiffSelfReorderingReabstractionThunk(
452+
CanType fromType, CanType toType, GenericSignature signature,
453+
AutoDiffLinearMapKind linearMapKind) {
454+
beginMangling();
455+
appendType(fromType);
456+
appendType(toType);
457+
if (signature)
458+
appendGenericSignature(signature);
459+
auto kindCode = (char)getAutoDiffFunctionKind(linearMapKind);
460+
appendOperator("TJO", StringRef(&kindCode, 1));
461+
return finalize();
462+
}
463+
451464
/// Mangle the index subset.
452465
void ASTMangler::appendIndexSubset(IndexSubset *indices) {
453466
Buffer << indices->getString();

lib/Demangling/Demangler.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,34 +2498,74 @@ NodePointer Demangler::demangleThunkOrSpecialization() {
24982498
return createNode(Node::Kind::OutlinedBridgedMethod, Params);
24992499
}
25002500
case 'u': return createNode(Node::Kind::AsyncFunctionPointer);
2501-
case 'J': {
2502-
auto result = createNode(Node::Kind::AutoDiffFunction);
2503-
auto optionalGenSig = popNode(Node::Kind::DependentGenericSignature);
2504-
auto original = popNode();
2505-
result = addChild(result, original);
2506-
addChild(result, optionalGenSig);
2507-
auto kind = demangleAutoDiffFunctionKind();
2508-
if (!kind)
2509-
return nullptr;
2510-
result = addChild(result, kind);
2511-
result = addChild(result, demangleIndexSubset());
2512-
if (!nextIf('p')) return nullptr;
2513-
result = addChild(result, demangleIndexSubset());
2514-
if (!nextIf('r')) return nullptr;
2515-
return result;
2516-
}
2501+
case 'J':
2502+
switch (peekChar()) {
2503+
case 'S':
2504+
return demangleAutoDiffSubsetParametersThunk();
2505+
case 'O':
2506+
return demangleAutoDiffSelfReorderingReabstractionThunk();
2507+
}
2508+
return demangleAutoDiffFunction();
25172509
default:
25182510
return nullptr;
25192511
}
25202512
}
25212513

2514+
NodePointer Demangler::demangleAutoDiffFunction() {
2515+
auto result = createNode(Node::Kind::AutoDiffFunction);
2516+
while (auto *originalNode = popNode())
2517+
result = addChild(result, originalNode);
2518+
result->reverseChildren();
2519+
auto kind = demangleAutoDiffFunctionKind();
2520+
result = addChild(result, kind);
2521+
result = addChild(result, demangleIndexSubset());
2522+
if (!nextIf('p'))
2523+
return nullptr;
2524+
result = addChild(result, demangleIndexSubset());
2525+
if (!nextIf('r'))
2526+
return nullptr;
2527+
return result;
2528+
}
2529+
25222530
NodePointer Demangler::demangleAutoDiffFunctionKind() {
25232531
char kind = nextChar();
25242532
if (kind != 'f' && kind != 'r' && kind != 'd' && kind != 'p')
25252533
return nullptr;
25262534
return createNode(Node::Kind::AutoDiffFunctionKind, kind);
25272535
}
25282536

2537+
NodePointer Demangler::demangleAutoDiffSubsetParametersThunk() {
2538+
nextChar();
2539+
auto result = createNode(Node::Kind::AutoDiffSubsetParametersThunk);
2540+
while (auto *node = popNode())
2541+
result = addChild(result, node);
2542+
result->reverseChildren();
2543+
auto kind = demangleAutoDiffFunctionKind();
2544+
result = addChild(result, kind);
2545+
result = addChild(result, demangleIndexSubset());
2546+
if (!nextIf('p'))
2547+
return nullptr;
2548+
result = addChild(result, demangleIndexSubset());
2549+
if (!nextIf('r'))
2550+
return nullptr;
2551+
result = addChild(result, demangleIndexSubset());
2552+
if (!nextIf('P'))
2553+
return nullptr;
2554+
return result;
2555+
}
2556+
2557+
NodePointer Demangler::demangleAutoDiffSelfReorderingReabstractionThunk() {
2558+
nextChar();
2559+
auto result = createNode(
2560+
Node::Kind::AutoDiffSelfReorderingReabstractionThunk);
2561+
addChild(result, popNode(Node::Kind::DependentGenericSignature));
2562+
result = addChild(result, popNode(Node::Kind::Type));
2563+
result = addChild(result, popNode(Node::Kind::Type));
2564+
result->reverseChildren();
2565+
result = addChild(result, demangleAutoDiffFunctionKind());
2566+
return result;
2567+
}
2568+
25292569
NodePointer Demangler::demangleIndexSubset() {
25302570
std::string str;
25312571
for (auto c = peekChar(); c == 'S' || c == 'U'; c = peekChar()) {

lib/Demangling/NodePrinter.cpp

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,8 @@ class NodePrinter {
566566
case Node::Kind::CanonicalPrespecializedGenericTypeCachingOnceToken:
567567
case Node::Kind::AsyncFunctionPointer:
568568
case Node::Kind::AutoDiffFunction:
569+
case Node::Kind::AutoDiffSelfReorderingReabstractionThunk:
570+
case Node::Kind::AutoDiffSubsetParametersThunk:
569571
case Node::Kind::AutoDiffFunctionKind:
570572
case Node::Kind::IndexSubset:
571573
return false;
@@ -1738,17 +1740,29 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
17381740
return nullptr;
17391741
}
17401742
case Node::Kind::AutoDiffFunction: {
1741-
auto childIt = Node->begin();
1742-
auto original = *childIt++;
1743-
NodePointer optionalGenSig =
1744-
(*childIt)->getKind() == Node::Kind::DependentGenericSignature
1745-
? *childIt++ : nullptr;
1746-
auto kind = *childIt++;
1747-
auto paramIndices = *childIt++;
1748-
auto resultIndices = *childIt++;
1743+
unsigned prefixEndIndex = 0;
1744+
while (prefixEndIndex != Node->getNumChildren() &&
1745+
Node->getChild(prefixEndIndex)->getKind()
1746+
!= Node::Kind::AutoDiffFunctionKind)
1747+
++prefixEndIndex;
1748+
auto kind = Node->getChild(prefixEndIndex);
1749+
auto paramIndices = Node->getChild(prefixEndIndex + 1);
1750+
auto resultIndices = Node->getChild(prefixEndIndex + 2);
17491751
print(kind);
17501752
Printer << " of ";
1751-
print(original);
1753+
NodePointer optionalGenSig = nullptr;
1754+
for (unsigned i = 0; i < prefixEndIndex; ++i) {
1755+
// The last node may be a generic signature. If so, print it later.
1756+
if (i == prefixEndIndex - 1 &&
1757+
Node->getChild(i)->getKind()
1758+
== Node::Kind::DependentGenericSignature) {
1759+
optionalGenSig = Node->getChild(i);
1760+
break;
1761+
}
1762+
print(Node->getChild(i));
1763+
}
1764+
if (Options.ShortenThunk)
1765+
return nullptr;
17521766
Printer << " with respect to parameters ";
17531767
print(paramIndices);
17541768
Printer << " and results ";
@@ -1759,6 +1773,61 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
17591773
}
17601774
return nullptr;
17611775
}
1776+
case Node::Kind::AutoDiffSelfReorderingReabstractionThunk: {
1777+
Printer << "autodiff self-reordering reabstraction thunk ";
1778+
auto childIt = Node->begin();
1779+
auto fromType = *childIt++;
1780+
auto toType = *childIt++;
1781+
if (Options.ShortenThunk) {
1782+
Printer << "for ";
1783+
print(fromType);
1784+
return nullptr;
1785+
}
1786+
NodePointer optionalGenSig =
1787+
(*childIt)->getKind() == Node::Kind::DependentGenericSignature
1788+
? *childIt++ : nullptr;
1789+
Printer << "for ";
1790+
print(*childIt++); // kind
1791+
if (optionalGenSig) {
1792+
print(optionalGenSig);
1793+
Printer << ' ';
1794+
}
1795+
Printer << " from ";
1796+
print(fromType);
1797+
Printer << " to ";
1798+
print(toType);
1799+
return nullptr;
1800+
}
1801+
case Node::Kind::AutoDiffSubsetParametersThunk: {
1802+
Printer << "autodiff subset parameters thunk for ";
1803+
auto currentIndex = Node->getNumChildren() - 1;
1804+
auto toParamIndices = Node->getChild(currentIndex--);
1805+
auto resultIndices = Node->getChild(currentIndex--);
1806+
auto paramIndices = Node->getChild(currentIndex--);
1807+
auto kind = Node->getChild(currentIndex--);
1808+
print(kind);
1809+
Printer << " from ";
1810+
// Print the "from" thing.
1811+
if (currentIndex == 0) {
1812+
print(Node->getFirstChild()); // the "from" type
1813+
} else {
1814+
for (unsigned i = 0; i < currentIndex; ++i) // the "from" global
1815+
print(Node->getChild(i));
1816+
}
1817+
if (Options.ShortenThunk)
1818+
return nullptr;
1819+
Printer << " with respect to parameters ";
1820+
print(paramIndices);
1821+
Printer << " and results ";
1822+
print(resultIndices);
1823+
Printer << " to parameters ";
1824+
print(toParamIndices);
1825+
if (currentIndex > 0) {
1826+
Printer << " of type ";
1827+
print(Node->getChild(currentIndex)); // "to" type
1828+
}
1829+
return nullptr;
1830+
}
17621831
case Node::Kind::AutoDiffFunctionKind: {
17631832
auto kind = (AutoDiffFunctionKind)Node->getIndex();
17641833
switch (kind) {

lib/Demangling/OldRemangler.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,14 @@ void Remangler::mangleAutoDiffFunction(Node *node, EntityContext &ctx) {
748748
Buffer << "<autodiff-function>";
749749
}
750750

751+
void Remangler::mangleAutoDiffSelfReorderingReabstractionThunk(Node *node) {
752+
Buffer << "<autodiff-self-reordering-reabstraction-thunk>";
753+
}
754+
755+
void Remangler::mangleAutoDiffSubsetParametersThunk(Node *node) {
756+
Buffer << "<autodiff-subset-parameters-thunk>";
757+
}
758+
751759
void Remangler::mangleAutoDiffFunctionKind(Node *node) {
752760
Buffer << "<autodiff-function-kind>";
753761
}

lib/Demangling/Remangler.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,15 +2116,40 @@ void Remangler::mangleReabstractionThunkHelperWithSelf(Node *node) {
21162116

21172117
void Remangler::mangleAutoDiffFunction(Node *node) {
21182118
auto childIt = node->begin();
2119-
mangle(*childIt++); // original
2119+
while (childIt != node->end() &&
2120+
(*childIt)->getKind() != Node::Kind::AutoDiffFunctionKind)
2121+
mangle(*childIt++);
2122+
Buffer << "TJ";
2123+
mangle(*childIt++); // kind
2124+
mangle(*childIt++); // parameter indices
2125+
Buffer << 'p';
2126+
mangle(*childIt++); // result indices
2127+
Buffer << 'r';
2128+
}
2129+
2130+
void Remangler::mangleAutoDiffSelfReorderingReabstractionThunk(Node *node) {
2131+
auto childIt = node->begin();
2132+
mangle(*childIt++); // from type
2133+
mangle(*childIt++); // to type
21202134
if ((*childIt)->getKind() == Node::Kind::DependentGenericSignature)
21212135
mangleDependentGenericSignature(*childIt++);
2122-
Buffer << "TJ";
2136+
Buffer << "TJO";
2137+
mangle(*childIt++); // kind
2138+
}
2139+
2140+
void Remangler::mangleAutoDiffSubsetParametersThunk(Node *node) {
2141+
auto childIt = node->begin();
2142+
while (childIt != node->end() &&
2143+
(*childIt)->getKind() != Node::Kind::AutoDiffFunctionKind)
2144+
mangle(*childIt++);
2145+
Buffer << "TJS";
21232146
mangle(*childIt++); // kind
21242147
mangle(*childIt++); // parameter indices
21252148
Buffer << 'p';
21262149
mangle(*childIt++); // result indices
21272150
Buffer << 'r';
2151+
mangle(*childIt++); // to parameter indices
2152+
Buffer << 'P';
21282153
}
21292154

21302155
void Remangler::mangleAutoDiffFunctionKind(Node *node) {

0 commit comments

Comments
 (0)