Skip to content

Fix mangling of '@noDerivative' and unify function attribute mangling operators. #36772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -570,19 +570,19 @@ Types
// they are mangled separately as part of the entity.
params-type ::= empty-list // shortcut for no parameters

sendable ::= 'J' // @Sendable on function types
async ::= 'Y' // 'async' annotation on function types
async ::= 'Ya' // 'async' annotation on function types
sendable ::= 'Yb' // @Sendable on function types
throws ::= 'K' // 'throws' annotation on function types
differentiable ::= 'jf' // @differentiable(_forward) on function type
differentiable ::= 'jr' // @differentiable(reverse) on function type
differentiable ::= 'jd' // @differentiable on function type
differentiable ::= 'jl' // @differentiable(_linear) on function type
differentiable ::= 'Yjf' // @differentiable(_forward) on function type
differentiable ::= 'Yjr' // @differentiable(reverse) on function type
differentiable ::= 'Yjd' // @differentiable on function type
differentiable ::= 'Yjl' // @differentiable(_linear) on function type

type-list ::= list-type '_' list-type* // list of types
type-list ::= empty-list

// FIXME: Consider replacing 'h' with a two-char code
list-type ::= type identifier? 'z'? 'h'? 'n'? 'd'? // type with optional label, inout convention, shared convention, owned convention, and variadic specifier
list-type ::= type identifier? 'Yk'? 'z'? 'h'? 'n'? 'd'? // type with optional label, '@noDerivative', inout convention, shared convention, owned convention, and variadic specifier

METATYPE-REPR ::= 't' // Thin metatype representation
METATYPE-REPR ::= 'T' // Thick metatype representation
Expand Down Expand Up @@ -666,7 +666,7 @@ mangled in to disambiguate.
COROUTINE-KIND ::= 'A' // yield-once coroutine
COROUTINE-KIND ::= 'G' // yield-many coroutine

SENDABLE ::= 'h' // @Sendable
SENDABLE ::= 'h' // @Sendable
ASYNC ::= 'H' // @async

PARAM-CONVENTION ::= 'i' // indirect in
Expand Down
2 changes: 0 additions & 2 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -2324,8 +2324,6 @@ class TypeAttributes {

Optional<Convention> ConventionArguments;

// Indicates whether the type's '@differentiable' attribute has a 'linear'
// argument.
DifferentiabilityKind differentiabilityKind =
DifferentiabilityKind::NonDifferentiable;

Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1929,7 +1929,7 @@ class ParameterTypeFlags {
NonEphemeral = 1 << 2,
OwnershipShift = 3,
Ownership = 7 << OwnershipShift,
NoDerivative = 1 << 7,
NoDerivative = 1 << 6,
NumBits = 7
};
OptionSet<ParameterFlags> value;
Expand Down
1 change: 1 addition & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ NODE(AutoDiffSelfReorderingReabstractionThunk)
NODE(AutoDiffSubsetParametersThunk)
NODE(AutoDiffDerivativeVTableThunk)
NODE(DifferentiabilityWitness)
NODE(NoDerivative)
NODE(IndexSubset)
NODE(AsyncAwaitResumePartialFunction)
NODE(AsyncSuspendResumePartialFunction)
Expand Down
2 changes: 2 additions & 0 deletions include/swift/Demangling/Demangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ class Demangler : public NodeFactory {

NodePointer demangleTypeMangling();
NodePointer demangleSymbolicReference(unsigned char rawKind);
NodePointer demangleTypeAnnotation();

NodePointer demangleAutoDiffFunctionOrSimpleThunk(Node::Kind nodeKind);
NodePointer demangleAutoDiffFunctionKind();
NodePointer demangleAutoDiffSubsetParametersThunk();
Expand Down
100 changes: 56 additions & 44 deletions include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class FunctionParam {
void setValueOwnership(ValueOwnership ownership) {
Flags = Flags.withValueOwnership(ownership);
}
void setNoDerivative() { Flags = Flags.withNoDerivative(true); }
void setFlags(ParameterFlags flags) { Flags = flags; };

FunctionParam withLabel(StringRef label) const {
Expand Down Expand Up @@ -737,27 +738,6 @@ class TypeDecoder {
++firstChildIdx;
}

bool isThrow = false;
if (Node->getChild(firstChildIdx)->getKind()
== NodeKind::ThrowsAnnotation) {
isThrow = true;
++firstChildIdx;
}

bool isSendable = false;
if (Node->getChild(firstChildIdx)->getKind()
== NodeKind::ConcurrentFunctionType) {
isSendable = true;
++firstChildIdx;
}

bool isAsync = false;
if (Node->getChild(firstChildIdx)->getKind()
== NodeKind::AsyncAnnotation) {
isAsync = true;
++firstChildIdx;
}

FunctionMetadataDifferentiabilityKind diffKind;
if (Node->getChild(firstChildIdx)->getKind() ==
NodeKind::DifferentiableFunctionType) {
Expand All @@ -783,6 +763,27 @@ class TypeDecoder {
++firstChildIdx;
}

bool isThrow = false;
if (Node->getChild(firstChildIdx)->getKind()
== NodeKind::ThrowsAnnotation) {
isThrow = true;
++firstChildIdx;
}

bool isSendable = false;
if (Node->getChild(firstChildIdx)->getKind()
== NodeKind::ConcurrentFunctionType) {
isSendable = true;
++firstChildIdx;
}

bool isAsync = false;
if (Node->getChild(firstChildIdx)->getKind()
== NodeKind::AsyncAnnotation) {
isAsync = true;
++firstChildIdx;
}

flags = flags.withConcurrent(isSendable)
.withAsync(isAsync).withThrows(isThrow)
.withDifferentiable(diffKind.isDifferentiable());
Expand Down Expand Up @@ -1370,33 +1371,44 @@ class TypeDecoder {
FunctionParam<BuiltType> &param) -> bool {
Demangle::NodePointer node = typeNode;

auto setOwnership = [&](ValueOwnership ownership) {
param.setValueOwnership(ownership);
node = node->getFirstChild();
hasParamFlags = true;
};
switch (node->getKind()) {
case NodeKind::InOut:
setOwnership(ValueOwnership::InOut);
break;
bool recurse = true;
while (recurse) {
switch (node->getKind()) {
case NodeKind::InOut:
param.setValueOwnership(ValueOwnership::InOut);
node = node->getFirstChild();
hasParamFlags = true;
break;

case NodeKind::Shared:
setOwnership(ValueOwnership::Shared);
break;
case NodeKind::Shared:
param.setValueOwnership(ValueOwnership::Shared);
node = node->getFirstChild();
hasParamFlags = true;
break;

case NodeKind::Owned:
setOwnership(ValueOwnership::Owned);
break;
case NodeKind::Owned:
param.setValueOwnership(ValueOwnership::Owned);
node = node->getFirstChild();
hasParamFlags = true;
break;

case NodeKind::AutoClosureType:
case NodeKind::EscapingAutoClosureType: {
param.setAutoClosure();
hasParamFlags = true;
break;
}
case NodeKind::NoDerivative:
param.setNoDerivative();
node = node->getFirstChild();
hasParamFlags = true;
break;

default:
break;
case NodeKind::AutoClosureType:
case NodeKind::EscapingAutoClosureType:
param.setAutoClosure();
hasParamFlags = true;
recurse = false;
break;

default:
recurse = false;
break;
}
}

auto paramType = decodeMangledType(node);
Expand Down
15 changes: 9 additions & 6 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2448,25 +2448,25 @@ void ASTMangler::appendFunctionSignature(AnyFunctionType *fn,
appendFunctionResultType(fn->getResult(), forDecl);
appendFunctionInputType(fn->getParams(), forDecl);
if (fn->isAsync() || functionMangling == AsyncHandlerBodyMangling)
appendOperator("Y");
appendOperator("Ya");
if (fn->isSendable())
appendOperator("J");
appendOperator("Yb");
if (fn->isThrowing())
appendOperator("K");
switch (auto diffKind = fn->getDifferentiabilityKind()) {
case DifferentiabilityKind::NonDifferentiable:
break;
case DifferentiabilityKind::Forward:
appendOperator("jf");
appendOperator("Yjf");
break;
case DifferentiabilityKind::Reverse:
appendOperator("jr");
appendOperator("Yjr");
break;
case DifferentiabilityKind::Normal:
appendOperator("jd");
appendOperator("Yjd");
break;
case DifferentiabilityKind::Linear:
appendOperator("jl");
appendOperator("Yjl");
break;
}
}
Expand Down Expand Up @@ -2541,6 +2541,9 @@ void ASTMangler::appendTypeListElement(Identifier name, Type elementType,
else
appendType(elementType, forDecl);

if (flags.isNoDerivative()) {
appendOperator("Yk");
}
switch (flags.getValueOwnership()) {
case ValueOwnership::Default:
/* nothing */
Expand Down
28 changes: 21 additions & 7 deletions lib/Demangling/Demangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,22 @@ NodePointer Demangler::demangleSymbolicReference(unsigned char rawKind) {
return resolved;
}

NodePointer Demangler::demangleTypeAnnotation() {
switch (char c2 = nextChar()) {
case 'a':
return createNode(Node::Kind::AsyncAnnotation);
case 'b':
return createNode(Node::Kind::ConcurrentFunctionType);
case 'j':
return demangleDifferentiableFunctionType();
case 'k':
return createType(
createWithChild(Node::Kind::NoDerivative, popTypeAndGetChild()));
default:
return nullptr;
}
}

NodePointer Demangler::demangleOperator() {
recur:
switch (unsigned char c = nextChar()) {
Expand Down Expand Up @@ -767,7 +783,6 @@ NodePointer Demangler::demangleOperator() {
}

case 'I': return demangleImplFunctionType();
case 'J': return createNode(Node::Kind::ConcurrentFunctionType);
case 'K': return createNode(Node::Kind::ThrowsAnnotation);
case 'L': return demangleLocalIdentifier();
case 'M': return demangleMetatype();
Expand All @@ -782,7 +797,7 @@ NodePointer Demangler::demangleOperator() {
case 'V': return demangleAnyGenericType(Node::Kind::Structure);
case 'W': return demangleWitness();
case 'X': return demangleSpecialType();
case 'Y': return createNode(Node::Kind::AsyncAnnotation);
case 'Y': return demangleTypeAnnotation();
case 'Z': return createWithChild(Node::Kind::Static, popNode(isEntity));
case 'a': return demangleAnyGenericType(Node::Kind::TypeAlias);
case 'c': return popFunctionType(Node::Kind::FunctionType);
Expand All @@ -792,7 +807,6 @@ NodePointer Demangler::demangleOperator() {
case 'h': return createType(createWithChild(Node::Kind::Shared,
popTypeAndGetChild()));
case 'i': return demangleSubscript();
case 'j': return demangleDifferentiableFunctionType();
case 'l': return demangleGenericSignature(/*hasParamCounts*/ false);
case 'm': return createType(createWithChild(Node::Kind::Metatype,
popNode(Node::Kind::Type)));
Expand Down Expand Up @@ -1254,10 +1268,10 @@ NodePointer Demangler::popFunctionType(Node::Kind kind, bool hasClangType) {
ClangType = demangleClangType();
}
addChild(FuncType, ClangType);
addChild(FuncType, popNode(Node::Kind::DifferentiableFunctionType));
addChild(FuncType, popNode(Node::Kind::ThrowsAnnotation));
addChild(FuncType, popNode(Node::Kind::ConcurrentFunctionType));
addChild(FuncType, popNode(Node::Kind::AsyncAnnotation));
addChild(FuncType, popNode(Node::Kind::DifferentiableFunctionType));

FuncType = addChild(FuncType, popFunctionParams(Node::Kind::ArgumentTuple));
FuncType = addChild(FuncType, popFunctionParams(Node::Kind::ReturnType));
Expand Down Expand Up @@ -1290,6 +1304,9 @@ NodePointer Demangler::popFunctionParamLabels(NodePointer Type) {
return nullptr;

unsigned FirstChildIdx = 0;
if (FuncType->getChild(FirstChildIdx)->getKind()
== Node::Kind::DifferentiableFunctionType)
++FirstChildIdx;
if (FuncType->getChild(FirstChildIdx)->getKind()
== Node::Kind::ThrowsAnnotation)
++FirstChildIdx;
Expand All @@ -1299,9 +1316,6 @@ NodePointer Demangler::popFunctionParamLabels(NodePointer Type) {
if (FuncType->getChild(FirstChildIdx)->getKind()
== Node::Kind::AsyncAnnotation)
++FirstChildIdx;
if (FuncType->getChild(FirstChildIdx)->getKind()
== Node::Kind::DifferentiableFunctionType)
++FirstChildIdx;
auto ParameterType = FuncType->getChild(FirstChildIdx);

assert(ParameterType->getKind() == Node::Kind::ArgumentTuple);
Expand Down
17 changes: 11 additions & 6 deletions lib/Demangling/NodePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ class NodePrinter {
case Node::Kind::AutoDiffSubsetParametersThunk:
case Node::Kind::AutoDiffFunctionKind:
case Node::Kind::DifferentiabilityWitness:
case Node::Kind::NoDerivative:
case Node::Kind::IndexSubset:
case Node::Kind::AsyncAwaitResumePartialFunction:
case Node::Kind::AsyncSuspendResumePartialFunction:
Expand Down Expand Up @@ -808,6 +809,12 @@ class NodePrinter {
unsigned startIndex = 0;
bool isSendable = false, isAsync = false, isThrows = false;
auto diffKind = MangledDifferentiabilityKind::NonDifferentiable;
if (node->getChild(startIndex)->getKind() ==
Node::Kind::DifferentiableFunctionType) {
diffKind =
(MangledDifferentiabilityKind)node->getChild(startIndex)->getIndex();
++startIndex;
}
if (node->getChild(startIndex)->getKind() == Node::Kind::ClangType) {
// handled earlier
++startIndex;
Expand All @@ -825,12 +832,6 @@ class NodePrinter {
++startIndex;
isAsync = true;
}
if (node->getChild(startIndex)->getKind() ==
Node::Kind::DifferentiableFunctionType) {
diffKind =
(MangledDifferentiabilityKind)node->getChild(startIndex)->getIndex();
++startIndex;
}

switch (diffKind) {
case MangledDifferentiabilityKind::Forward:
Expand Down Expand Up @@ -1421,6 +1422,10 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
Printer << "__owned ";
print(Node->getChild(0));
return nullptr;
case Node::Kind::NoDerivative:
Printer << "@noDerivative ";
print(Node->getChild(0));
return nullptr;
case Node::Kind::NonObjCAttribute:
Printer << "@nonobjc ";
return nullptr;
Expand Down
8 changes: 8 additions & 0 deletions lib/Demangling/OldDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2063,6 +2063,14 @@ class OldDemangler {
inout->addChild(type, Factory);
return inout;
}
if (c == 'k') {
auto noDerivative = Factory.createNode(Node::Kind::NoDerivative);
auto type = demangleTypeImpl();
if (!type)
return nullptr;
noDerivative->addChild(type, Factory);
return noDerivative;
}
if (c == 'S') {
return demangleSubstitutionIndex();
}
Expand Down
5 changes: 5 additions & 0 deletions lib/Demangling/OldRemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,11 @@ void Remangler::mangleInOut(Node *node) {
mangleSingleChildNode(node); // type
}

void Remangler::mangleNoDerivative(Node *node) {
Buffer << 'k';
mangleSingleChildNode(node); // type
}

void Remangler::mangleTuple(Node *node) {
size_t NumElems = node->getNumChildren();
if (NumElems > 0 &&
Expand Down
Loading