Skip to content

Commit fd06683

Browse files
authored
Merge pull request #36772 from rxwei/75916833-noderivative
2 parents 6e2338c + fb66de6 commit fd06683

File tree

87 files changed

+464
-350
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+464
-350
lines changed

docs/ABI/Mangling.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,19 +570,19 @@ Types
570570
// they are mangled separately as part of the entity.
571571
params-type ::= empty-list // shortcut for no parameters
572572

573-
sendable ::= 'J' // @Sendable on function types
574-
async ::= 'Y' // 'async' annotation on function types
573+
async ::= 'Ya' // 'async' annotation on function types
574+
sendable ::= 'Yb' // @Sendable on function types
575575
throws ::= 'K' // 'throws' annotation on function types
576-
differentiable ::= 'jf' // @differentiable(_forward) on function type
577-
differentiable ::= 'jr' // @differentiable(reverse) on function type
578-
differentiable ::= 'jd' // @differentiable on function type
579-
differentiable ::= 'jl' // @differentiable(_linear) on function type
576+
differentiable ::= 'Yjf' // @differentiable(_forward) on function type
577+
differentiable ::= 'Yjr' // @differentiable(reverse) on function type
578+
differentiable ::= 'Yjd' // @differentiable on function type
579+
differentiable ::= 'Yjl' // @differentiable(_linear) on function type
580580

581581
type-list ::= list-type '_' list-type* // list of types
582582
type-list ::= empty-list
583583

584584
// FIXME: Consider replacing 'h' with a two-char code
585-
list-type ::= type identifier? 'z'? 'h'? 'n'? 'd'? // type with optional label, inout convention, shared convention, owned convention, and variadic specifier
585+
list-type ::= type identifier? 'Yk'? 'z'? 'h'? 'n'? 'd'? // type with optional label, '@noDerivative', inout convention, shared convention, owned convention, and variadic specifier
586586

587587
METATYPE-REPR ::= 't' // Thin metatype representation
588588
METATYPE-REPR ::= 'T' // Thick metatype representation
@@ -666,7 +666,7 @@ mangled in to disambiguate.
666666
COROUTINE-KIND ::= 'A' // yield-once coroutine
667667
COROUTINE-KIND ::= 'G' // yield-many coroutine
668668

669-
SENDABLE ::= 'h' // @Sendable
669+
SENDABLE ::= 'h' // @Sendable
670670
ASYNC ::= 'H' // @async
671671

672672
PARAM-CONVENTION ::= 'i' // indirect in

include/swift/AST/Attr.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,8 +2324,6 @@ class TypeAttributes {
23242324

23252325
Optional<Convention> ConventionArguments;
23262326

2327-
// Indicates whether the type's '@differentiable' attribute has a 'linear'
2328-
// argument.
23292327
DifferentiabilityKind differentiabilityKind =
23302328
DifferentiabilityKind::NonDifferentiable;
23312329

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1929,7 +1929,7 @@ class ParameterTypeFlags {
19291929
NonEphemeral = 1 << 2,
19301930
OwnershipShift = 3,
19311931
Ownership = 7 << OwnershipShift,
1932-
NoDerivative = 1 << 7,
1932+
NoDerivative = 1 << 6,
19331933
NumBits = 7
19341934
};
19351935
OptionSet<ParameterFlags> value;

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ NODE(AutoDiffSelfReorderingReabstractionThunk)
312312
NODE(AutoDiffSubsetParametersThunk)
313313
NODE(AutoDiffDerivativeVTableThunk)
314314
NODE(DifferentiabilityWitness)
315+
NODE(NoDerivative)
315316
NODE(IndexSubset)
316317
NODE(AsyncAwaitResumePartialFunction)
317318
NODE(AsyncSuspendResumePartialFunction)

include/swift/Demangling/Demangler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,8 @@ class Demangler : public NodeFactory {
569569

570570
NodePointer demangleTypeMangling();
571571
NodePointer demangleSymbolicReference(unsigned char rawKind);
572+
NodePointer demangleTypeAnnotation();
573+
572574
NodePointer demangleAutoDiffFunctionOrSimpleThunk(Node::Kind nodeKind);
573575
NodePointer demangleAutoDiffFunctionKind();
574576
NodePointer demangleAutoDiffSubsetParametersThunk();

include/swift/Demangling/TypeDecoder.h

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class FunctionParam {
7272
void setValueOwnership(ValueOwnership ownership) {
7373
Flags = Flags.withValueOwnership(ownership);
7474
}
75+
void setNoDerivative() { Flags = Flags.withNoDerivative(true); }
7576
void setFlags(ParameterFlags flags) { Flags = flags; };
7677

7778
FunctionParam withLabel(StringRef label) const {
@@ -737,27 +738,6 @@ class TypeDecoder {
737738
++firstChildIdx;
738739
}
739740

740-
bool isThrow = false;
741-
if (Node->getChild(firstChildIdx)->getKind()
742-
== NodeKind::ThrowsAnnotation) {
743-
isThrow = true;
744-
++firstChildIdx;
745-
}
746-
747-
bool isSendable = false;
748-
if (Node->getChild(firstChildIdx)->getKind()
749-
== NodeKind::ConcurrentFunctionType) {
750-
isSendable = true;
751-
++firstChildIdx;
752-
}
753-
754-
bool isAsync = false;
755-
if (Node->getChild(firstChildIdx)->getKind()
756-
== NodeKind::AsyncAnnotation) {
757-
isAsync = true;
758-
++firstChildIdx;
759-
}
760-
761741
FunctionMetadataDifferentiabilityKind diffKind;
762742
if (Node->getChild(firstChildIdx)->getKind() ==
763743
NodeKind::DifferentiableFunctionType) {
@@ -783,6 +763,27 @@ class TypeDecoder {
783763
++firstChildIdx;
784764
}
785765

766+
bool isThrow = false;
767+
if (Node->getChild(firstChildIdx)->getKind()
768+
== NodeKind::ThrowsAnnotation) {
769+
isThrow = true;
770+
++firstChildIdx;
771+
}
772+
773+
bool isSendable = false;
774+
if (Node->getChild(firstChildIdx)->getKind()
775+
== NodeKind::ConcurrentFunctionType) {
776+
isSendable = true;
777+
++firstChildIdx;
778+
}
779+
780+
bool isAsync = false;
781+
if (Node->getChild(firstChildIdx)->getKind()
782+
== NodeKind::AsyncAnnotation) {
783+
isAsync = true;
784+
++firstChildIdx;
785+
}
786+
786787
flags = flags.withConcurrent(isSendable)
787788
.withAsync(isAsync).withThrows(isThrow)
788789
.withDifferentiable(diffKind.isDifferentiable());
@@ -1370,33 +1371,44 @@ class TypeDecoder {
13701371
FunctionParam<BuiltType> &param) -> bool {
13711372
Demangle::NodePointer node = typeNode;
13721373

1373-
auto setOwnership = [&](ValueOwnership ownership) {
1374-
param.setValueOwnership(ownership);
1375-
node = node->getFirstChild();
1376-
hasParamFlags = true;
1377-
};
1378-
switch (node->getKind()) {
1379-
case NodeKind::InOut:
1380-
setOwnership(ValueOwnership::InOut);
1381-
break;
1374+
bool recurse = true;
1375+
while (recurse) {
1376+
switch (node->getKind()) {
1377+
case NodeKind::InOut:
1378+
param.setValueOwnership(ValueOwnership::InOut);
1379+
node = node->getFirstChild();
1380+
hasParamFlags = true;
1381+
break;
13821382

1383-
case NodeKind::Shared:
1384-
setOwnership(ValueOwnership::Shared);
1385-
break;
1383+
case NodeKind::Shared:
1384+
param.setValueOwnership(ValueOwnership::Shared);
1385+
node = node->getFirstChild();
1386+
hasParamFlags = true;
1387+
break;
13861388

1387-
case NodeKind::Owned:
1388-
setOwnership(ValueOwnership::Owned);
1389-
break;
1389+
case NodeKind::Owned:
1390+
param.setValueOwnership(ValueOwnership::Owned);
1391+
node = node->getFirstChild();
1392+
hasParamFlags = true;
1393+
break;
13901394

1391-
case NodeKind::AutoClosureType:
1392-
case NodeKind::EscapingAutoClosureType: {
1393-
param.setAutoClosure();
1394-
hasParamFlags = true;
1395-
break;
1396-
}
1395+
case NodeKind::NoDerivative:
1396+
param.setNoDerivative();
1397+
node = node->getFirstChild();
1398+
hasParamFlags = true;
1399+
break;
13971400

1398-
default:
1399-
break;
1401+
case NodeKind::AutoClosureType:
1402+
case NodeKind::EscapingAutoClosureType:
1403+
param.setAutoClosure();
1404+
hasParamFlags = true;
1405+
recurse = false;
1406+
break;
1407+
1408+
default:
1409+
recurse = false;
1410+
break;
1411+
}
14001412
}
14011413

14021414
auto paramType = decodeMangledType(node);

lib/AST/ASTMangler.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,25 +2448,25 @@ void ASTMangler::appendFunctionSignature(AnyFunctionType *fn,
24482448
appendFunctionResultType(fn->getResult(), forDecl);
24492449
appendFunctionInputType(fn->getParams(), forDecl);
24502450
if (fn->isAsync() || functionMangling == AsyncHandlerBodyMangling)
2451-
appendOperator("Y");
2451+
appendOperator("Ya");
24522452
if (fn->isSendable())
2453-
appendOperator("J");
2453+
appendOperator("Yb");
24542454
if (fn->isThrowing())
24552455
appendOperator("K");
24562456
switch (auto diffKind = fn->getDifferentiabilityKind()) {
24572457
case DifferentiabilityKind::NonDifferentiable:
24582458
break;
24592459
case DifferentiabilityKind::Forward:
2460-
appendOperator("jf");
2460+
appendOperator("Yjf");
24612461
break;
24622462
case DifferentiabilityKind::Reverse:
2463-
appendOperator("jr");
2463+
appendOperator("Yjr");
24642464
break;
24652465
case DifferentiabilityKind::Normal:
2466-
appendOperator("jd");
2466+
appendOperator("Yjd");
24672467
break;
24682468
case DifferentiabilityKind::Linear:
2469-
appendOperator("jl");
2469+
appendOperator("Yjl");
24702470
break;
24712471
}
24722472
}
@@ -2541,6 +2541,9 @@ void ASTMangler::appendTypeListElement(Identifier name, Type elementType,
25412541
else
25422542
appendType(elementType, forDecl);
25432543

2544+
if (flags.isNoDerivative()) {
2545+
appendOperator("Yk");
2546+
}
25442547
switch (flags.getValueOwnership()) {
25452548
case ValueOwnership::Default:
25462549
/* nothing */

lib/Demangling/Demangler.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,22 @@ NodePointer Demangler::demangleSymbolicReference(unsigned char rawKind) {
730730
return resolved;
731731
}
732732

733+
NodePointer Demangler::demangleTypeAnnotation() {
734+
switch (char c2 = nextChar()) {
735+
case 'a':
736+
return createNode(Node::Kind::AsyncAnnotation);
737+
case 'b':
738+
return createNode(Node::Kind::ConcurrentFunctionType);
739+
case 'j':
740+
return demangleDifferentiableFunctionType();
741+
case 'k':
742+
return createType(
743+
createWithChild(Node::Kind::NoDerivative, popTypeAndGetChild()));
744+
default:
745+
return nullptr;
746+
}
747+
}
748+
733749
NodePointer Demangler::demangleOperator() {
734750
recur:
735751
switch (unsigned char c = nextChar()) {
@@ -767,7 +783,6 @@ NodePointer Demangler::demangleOperator() {
767783
}
768784

769785
case 'I': return demangleImplFunctionType();
770-
case 'J': return createNode(Node::Kind::ConcurrentFunctionType);
771786
case 'K': return createNode(Node::Kind::ThrowsAnnotation);
772787
case 'L': return demangleLocalIdentifier();
773788
case 'M': return demangleMetatype();
@@ -782,7 +797,7 @@ NodePointer Demangler::demangleOperator() {
782797
case 'V': return demangleAnyGenericType(Node::Kind::Structure);
783798
case 'W': return demangleWitness();
784799
case 'X': return demangleSpecialType();
785-
case 'Y': return createNode(Node::Kind::AsyncAnnotation);
800+
case 'Y': return demangleTypeAnnotation();
786801
case 'Z': return createWithChild(Node::Kind::Static, popNode(isEntity));
787802
case 'a': return demangleAnyGenericType(Node::Kind::TypeAlias);
788803
case 'c': return popFunctionType(Node::Kind::FunctionType);
@@ -792,7 +807,6 @@ NodePointer Demangler::demangleOperator() {
792807
case 'h': return createType(createWithChild(Node::Kind::Shared,
793808
popTypeAndGetChild()));
794809
case 'i': return demangleSubscript();
795-
case 'j': return demangleDifferentiableFunctionType();
796810
case 'l': return demangleGenericSignature(/*hasParamCounts*/ false);
797811
case 'm': return createType(createWithChild(Node::Kind::Metatype,
798812
popNode(Node::Kind::Type)));
@@ -1254,10 +1268,10 @@ NodePointer Demangler::popFunctionType(Node::Kind kind, bool hasClangType) {
12541268
ClangType = demangleClangType();
12551269
}
12561270
addChild(FuncType, ClangType);
1271+
addChild(FuncType, popNode(Node::Kind::DifferentiableFunctionType));
12571272
addChild(FuncType, popNode(Node::Kind::ThrowsAnnotation));
12581273
addChild(FuncType, popNode(Node::Kind::ConcurrentFunctionType));
12591274
addChild(FuncType, popNode(Node::Kind::AsyncAnnotation));
1260-
addChild(FuncType, popNode(Node::Kind::DifferentiableFunctionType));
12611275

12621276
FuncType = addChild(FuncType, popFunctionParams(Node::Kind::ArgumentTuple));
12631277
FuncType = addChild(FuncType, popFunctionParams(Node::Kind::ReturnType));
@@ -1290,6 +1304,9 @@ NodePointer Demangler::popFunctionParamLabels(NodePointer Type) {
12901304
return nullptr;
12911305

12921306
unsigned FirstChildIdx = 0;
1307+
if (FuncType->getChild(FirstChildIdx)->getKind()
1308+
== Node::Kind::DifferentiableFunctionType)
1309+
++FirstChildIdx;
12931310
if (FuncType->getChild(FirstChildIdx)->getKind()
12941311
== Node::Kind::ThrowsAnnotation)
12951312
++FirstChildIdx;
@@ -1299,9 +1316,6 @@ NodePointer Demangler::popFunctionParamLabels(NodePointer Type) {
12991316
if (FuncType->getChild(FirstChildIdx)->getKind()
13001317
== Node::Kind::AsyncAnnotation)
13011318
++FirstChildIdx;
1302-
if (FuncType->getChild(FirstChildIdx)->getKind()
1303-
== Node::Kind::DifferentiableFunctionType)
1304-
++FirstChildIdx;
13051319
auto ParameterType = FuncType->getChild(FirstChildIdx);
13061320

13071321
assert(ParameterType->getKind() == Node::Kind::ArgumentTuple);

lib/Demangling/NodePrinter.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ class NodePrinter {
569569
case Node::Kind::AutoDiffSubsetParametersThunk:
570570
case Node::Kind::AutoDiffFunctionKind:
571571
case Node::Kind::DifferentiabilityWitness:
572+
case Node::Kind::NoDerivative:
572573
case Node::Kind::IndexSubset:
573574
case Node::Kind::AsyncAwaitResumePartialFunction:
574575
case Node::Kind::AsyncSuspendResumePartialFunction:
@@ -808,6 +809,12 @@ class NodePrinter {
808809
unsigned startIndex = 0;
809810
bool isSendable = false, isAsync = false, isThrows = false;
810811
auto diffKind = MangledDifferentiabilityKind::NonDifferentiable;
812+
if (node->getChild(startIndex)->getKind() ==
813+
Node::Kind::DifferentiableFunctionType) {
814+
diffKind =
815+
(MangledDifferentiabilityKind)node->getChild(startIndex)->getIndex();
816+
++startIndex;
817+
}
811818
if (node->getChild(startIndex)->getKind() == Node::Kind::ClangType) {
812819
// handled earlier
813820
++startIndex;
@@ -825,12 +832,6 @@ class NodePrinter {
825832
++startIndex;
826833
isAsync = true;
827834
}
828-
if (node->getChild(startIndex)->getKind() ==
829-
Node::Kind::DifferentiableFunctionType) {
830-
diffKind =
831-
(MangledDifferentiabilityKind)node->getChild(startIndex)->getIndex();
832-
++startIndex;
833-
}
834835

835836
switch (diffKind) {
836837
case MangledDifferentiabilityKind::Forward:
@@ -1421,6 +1422,10 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
14211422
Printer << "__owned ";
14221423
print(Node->getChild(0));
14231424
return nullptr;
1425+
case Node::Kind::NoDerivative:
1426+
Printer << "@noDerivative ";
1427+
print(Node->getChild(0));
1428+
return nullptr;
14241429
case Node::Kind::NonObjCAttribute:
14251430
Printer << "@nonobjc ";
14261431
return nullptr;

lib/Demangling/OldDemangler.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,14 @@ class OldDemangler {
20632063
inout->addChild(type, Factory);
20642064
return inout;
20652065
}
2066+
if (c == 'k') {
2067+
auto noDerivative = Factory.createNode(Node::Kind::NoDerivative);
2068+
auto type = demangleTypeImpl();
2069+
if (!type)
2070+
return nullptr;
2071+
noDerivative->addChild(type, Factory);
2072+
return noDerivative;
2073+
}
20662074
if (c == 'S') {
20672075
return demangleSubstitutionIndex();
20682076
}

lib/Demangling/OldRemangler.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,6 +1485,11 @@ void Remangler::mangleInOut(Node *node) {
14851485
mangleSingleChildNode(node); // type
14861486
}
14871487

1488+
void Remangler::mangleNoDerivative(Node *node) {
1489+
Buffer << 'k';
1490+
mangleSingleChildNode(node); // type
1491+
}
1492+
14881493
void Remangler::mangleTuple(Node *node) {
14891494
size_t NumElems = node->getNumChildren();
14901495
if (NumElems > 0 &&

0 commit comments

Comments
 (0)