Skip to content

Commit a8a05bc

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Mangle @noDerivative parameters. (swiftlang#31201)
Mangle `@noDerivative` parameters to fix type reconstruction errors. Resolves SR-12650. The new mangling is non-breaking. When differentiation supports multiple result indices and `@noDerivative` results are added, we can reuse some of this mangling support.
1 parent 1c00f54 commit a8a05bc

File tree

12 files changed

+192
-9
lines changed

12 files changed

+192
-9
lines changed

docs/ABI/Mangling.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ mangled in to disambiguate.
589589
impl-function-type ::= type* 'I' FUNC-ATTRIBUTES '_'
590590
impl-function-type ::= type* generic-signature 'I' FUNC-ATTRIBUTES '_'
591591

592-
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? PARAM-CONVENTION* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
592+
FUNC-ATTRIBUTES ::= PATTERN-SUBS? INVOCATION-SUBS? PSEUDO-GENERIC? CALLEE-ESCAPE? DIFFERENTIABILITY-KIND? CALLEE-CONVENTION FUNC-REPRESENTATION? COROUTINE-KIND? (PARAM-CONVENTION PARAM-DIFFERENTIABILITY?)* RESULT-CONVENTION* ('Y' PARAM-CONVENTION)* ('z' RESULT-CONVENTION)?
593593

594594
PATTERN-SUBS ::= 's' // has pattern substitutions
595595
INVOCATION-SUB ::= 'I' // has invocation substitutions
@@ -626,6 +626,8 @@ mangled in to disambiguate.
626626
PARAM-CONVENTION ::= 'g' // direct guaranteed
627627
PARAM-CONVENTION ::= 'e' // direct deallocating
628628

629+
PARAM-DIFFERENTIABILITY ::= 'w' // @noDerivative
630+
629631
RESULT-CONVENTION ::= 'r' // indirect
630632
RESULT-CONVENTION ::= 'o' // owned
631633
RESULT-CONVENTION ::= 'd' // unowned

include/swift/Demangling/DemangleNodes.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ NODE(ImplDifferentiable)
117117
NODE(ImplLinear)
118118
NODE(ImplEscaping)
119119
NODE(ImplConvention)
120+
NODE(ImplDifferentiability)
120121
NODE(ImplFunctionAttribute)
121122
NODE(ImplFunctionType)
122123
NODE(ImplInvocationSubstitutions)

include/swift/Demangling/Demangler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class Demangler : public NodeFactory {
518518
NodePointer demangleInitializer();
519519
NodePointer demangleImplParamConvention(Node::Kind ConvKind);
520520
NodePointer demangleImplResultConvention(Node::Kind ConvKind);
521+
NodePointer demangleImplDifferentiability();
521522
NodePointer demangleImplFunctionType();
522523
NodePointer demangleMetatype();
523524
NodePointer demanglePrivateContextDescriptor();

include/swift/Demangling/TypeDecoder.h

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,31 @@ enum class ImplParameterConvention {
8888
Direct_Guaranteed,
8989
};
9090

91+
enum class ImplParameterDifferentiability {
92+
DifferentiableOrNotApplicable,
93+
NotDifferentiable
94+
};
95+
96+
static inline Optional<ImplParameterDifferentiability>
97+
getDifferentiabilityFromString(StringRef string) {
98+
if (string.empty())
99+
return ImplParameterDifferentiability::DifferentiableOrNotApplicable;
100+
if (string == "@noDerivative")
101+
return ImplParameterDifferentiability::NotDifferentiable;
102+
return None;
103+
}
104+
91105
/// Describe a lowered function parameter, parameterized on the type
92106
/// representation.
93107
template <typename BuiltType>
94108
class ImplFunctionParam {
95109
ImplParameterConvention Convention;
110+
ImplParameterDifferentiability Differentiability;
96111
BuiltType Type;
97112

98113
public:
99114
using ConventionType = ImplParameterConvention;
115+
using DifferentiabilityType = ImplParameterDifferentiability;
100116

101117
static Optional<ConventionType>
102118
getConventionFromString(StringRef conventionString) {
@@ -120,11 +136,16 @@ class ImplFunctionParam {
120136
return None;
121137
}
122138

123-
ImplFunctionParam(ImplParameterConvention convention, BuiltType type)
124-
: Convention(convention), Type(type) {}
139+
ImplFunctionParam(ImplParameterConvention convention,
140+
ImplParameterDifferentiability diffKind, BuiltType type)
141+
: Convention(convention), Differentiability(diffKind), Type(type) {}
125142

126143
ImplParameterConvention getConvention() const { return Convention; }
127144

145+
ImplParameterDifferentiability getDifferentiability() const {
146+
return Differentiability;
147+
}
148+
128149
BuiltType getType() const { return Type; }
129150
};
130151

@@ -614,10 +635,8 @@ class TypeDecoder {
614635
ImplFunctionDifferentiabilityKind::Linear);
615636
} else if (child->getKind() == NodeKind::ImplEscaping) {
616637
flags = flags.withEscaping();
617-
} else if (child->getKind() == NodeKind::ImplEscaping) {
618-
flags = flags.withEscaping();
619638
} else if (child->getKind() == NodeKind::ImplParameter) {
620-
if (decodeImplFunctionPart(child, parameters))
639+
if (decodeImplFunctionParam(child, parameters))
621640
return BuiltType();
622641
} else if (child->getKind() == NodeKind::ImplResult) {
623642
if (decodeImplFunctionPart(child, results))
@@ -897,6 +916,45 @@ class TypeDecoder {
897916
return false;
898917
}
899918

919+
bool decodeImplFunctionParam(
920+
Demangle::NodePointer node,
921+
SmallVectorImpl<ImplFunctionParam<BuiltType>> &results) {
922+
// Children: `convention, differentiability?, type`
923+
if (node->getNumChildren() != 2 && node->getNumChildren() != 3)
924+
return true;
925+
926+
auto *conventionNode = node->getChild(0);
927+
auto *typeNode = node->getLastChild();
928+
if (conventionNode->getKind() != Node::Kind::ImplConvention ||
929+
typeNode->getKind() != Node::Kind::Type)
930+
return true;
931+
932+
StringRef conventionString = conventionNode->getText();
933+
auto convention =
934+
ImplFunctionParam<BuiltType>::getConventionFromString(conventionString);
935+
if (!convention)
936+
return true;
937+
BuiltType type = decodeMangledType(typeNode);
938+
if (!type)
939+
return true;
940+
941+
auto diffKind =
942+
ImplParameterDifferentiability::DifferentiableOrNotApplicable;
943+
if (node->getNumChildren() == 3) {
944+
auto diffKindNode = node->getChild(1);
945+
if (diffKindNode->getKind() != Node::Kind::ImplDifferentiability)
946+
return true;
947+
auto optDiffKind =
948+
getDifferentiabilityFromString(diffKindNode->getText());
949+
if (!optDiffKind)
950+
return true;
951+
diffKind = *optDiffKind;
952+
}
953+
954+
results.emplace_back(*convention, diffKind, type);
955+
return false;
956+
}
957+
900958
bool decodeMangledTypeDecl(Demangle::NodePointer node,
901959
BuiltTypeDecl &typeDecl,
902960
BuiltType &parent,

lib/AST/ASTDemangler.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,16 @@ getParameterConvention(ImplParameterConvention conv) {
444444
llvm_unreachable("covered switch");
445445
}
446446

447+
static SILParameterDifferentiability
448+
getParameterDifferentiability(ImplParameterDifferentiability diffKind) {
449+
switch (diffKind) {
450+
case ImplParameterDifferentiability::DifferentiableOrNotApplicable:
451+
return SILParameterDifferentiability::DifferentiableOrNotApplicable;
452+
case ImplParameterDifferentiability::NotDifferentiable:
453+
return SILParameterDifferentiability::NotDifferentiable;
454+
}
455+
}
456+
447457
static ResultConvention getResultConvention(ImplResultConvention conv) {
448458
switch (conv) {
449459
case Demangle::ImplResultConvention::Indirect:
@@ -526,7 +536,8 @@ Type ASTBuilder::createImplFunctionType(
526536
for (const auto &param : params) {
527537
auto type = param.getType()->getCanonicalType();
528538
auto conv = getParameterConvention(param.getConvention());
529-
funcParams.emplace_back(type, conv);
539+
auto diffKind = getParameterDifferentiability(param.getDifferentiability());
540+
funcParams.emplace_back(type, conv, diffKind);
530541
}
531542

532543
for (const auto &result : results) {

lib/AST/ASTMangler.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,6 +1569,17 @@ static char getParamConvention(ParameterConvention conv) {
15691569
llvm_unreachable("bad parameter convention");
15701570
};
15711571

1572+
static Optional<char>
1573+
getParamDifferentiability(SILParameterDifferentiability diffKind) {
1574+
switch (diffKind) {
1575+
case swift::SILParameterDifferentiability::DifferentiableOrNotApplicable:
1576+
return None;
1577+
case swift::SILParameterDifferentiability::NotDifferentiable:
1578+
return 'w';
1579+
}
1580+
llvm_unreachable("bad parameter convention");
1581+
};
1582+
15721583
static char getResultConvention(ResultConvention conv) {
15731584
switch (conv) {
15741585
case ResultConvention::Indirect: return 'r';
@@ -1658,6 +1669,8 @@ void ASTMangler::appendImplFunctionType(SILFunctionType *fn) {
16581669
// Mangle the parameters.
16591670
for (auto param : fn->getParameters()) {
16601671
OpArgs.push_back(getParamConvention(param.getConvention()));
1672+
if (auto diffKind = getParamDifferentiability(param.getDifferentiability()))
1673+
OpArgs.push_back(*diffKind);
16611674
appendType(param.getInterfaceType());
16621675
}
16631676

lib/Demangling/Demangler.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,14 @@ NodePointer Demangler::demangleImplResultConvention(Node::Kind ConvKind) {
17321732
createNode(Node::Kind::ImplConvention, attr));
17331733
}
17341734

1735+
NodePointer Demangler::demangleImplDifferentiability() {
1736+
// Empty string represents default differentiability.
1737+
const char *attr = "";
1738+
if (nextIf('w'))
1739+
attr = "@noDerivative";
1740+
return createNode(Node::Kind::ImplDifferentiability, attr);
1741+
}
1742+
17351743
NodePointer Demangler::demangleImplFunctionType() {
17361744
NodePointer type = createNode(Node::Kind::ImplFunctionType);
17371745

@@ -1817,8 +1825,10 @@ NodePointer Demangler::demangleImplFunctionType() {
18171825

18181826
int NumTypesToAdd = 0;
18191827
while (NodePointer Param =
1820-
demangleImplParamConvention(Node::Kind::ImplParameter)) {
1828+
demangleImplParamConvention(Node::Kind::ImplParameter)) {
18211829
type = addChild(type, Param);
1830+
if (NodePointer Diff = demangleImplDifferentiability())
1831+
Param = addChild(Param, Diff);
18221832
NumTypesToAdd++;
18231833
}
18241834
while (NodePointer Result = demangleImplResultConvention(

lib/Demangling/NodePrinter.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ class NodePrinter {
394394
case Node::Kind::ImplLinear:
395395
case Node::Kind::ImplEscaping:
396396
case Node::Kind::ImplConvention:
397+
case Node::Kind::ImplDifferentiability:
397398
case Node::Kind::ImplFunctionAttribute:
398399
case Node::Kind::ImplFunctionType:
399400
case Node::Kind::ImplInvocationSubstitutions:
@@ -2060,6 +2061,13 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
20602061
case Node::Kind::ImplConvention:
20612062
Printer << Node->getText();
20622063
return nullptr;
2064+
case Node::Kind::ImplDifferentiability:
2065+
// Skip if text is empty.
2066+
if (Node->getText().empty())
2067+
return nullptr;
2068+
// Otherwise, print with trailing space.
2069+
Printer << Node->getText() << ' ';
2070+
return nullptr;
20632071
case Node::Kind::ImplFunctionAttribute:
20642072
Printer << Node->getText();
20652073
return nullptr;
@@ -2072,6 +2080,16 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
20722080
printChildren(Node, " ");
20732081
return nullptr;
20742082
case Node::Kind::ImplParameter:
2083+
// Children: `convention, differentiability?, type`
2084+
// Print convention.
2085+
print(Node->getChild(0));
2086+
Printer << " ";
2087+
// Print differentiability, if it exists.
2088+
if (Node->getNumChildren() == 3)
2089+
print(Node->getChild(1));
2090+
// Print type.
2091+
print(Node->getLastChild());
2092+
return nullptr;
20752093
case Node::Kind::ImplResult:
20762094
printChildren(Node, " ");
20772095
return nullptr;

lib/Demangling/OldRemangler.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,19 @@ void Remangler::mangleImplConvention(Node *node) {
13261326
}
13271327
}
13281328

1329+
void Remangler::mangleImplDifferentiability(Node *node) {
1330+
assert(node->getKind() == Node::Kind::ImplDifferentiability);
1331+
StringRef text = node->getText();
1332+
// Empty string represents default differentiability.
1333+
if (text.empty())
1334+
return;
1335+
if (text == "@noDerivative") {
1336+
Buffer << 'w';
1337+
return;
1338+
}
1339+
unreachable("Invalid impl differentiability");
1340+
}
1341+
13291342
void Remangler::mangleDynamicSelf(Node *node) {
13301343
Buffer << 'D';
13311344
mangleSingleChildNode(node); // type

lib/Demangling/Remangler.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,18 @@ void Remangler::mangleImplConvention(Node *node) {
14201420
Buffer << ConvCh;
14211421
}
14221422

1423+
void Remangler::mangleImplDifferentiability(Node *node) {
1424+
assert(node->hasText());
1425+
// Empty string represents default differentiability.
1426+
if (node->getText().empty())
1427+
return;
1428+
char diffChar = llvm::StringSwitch<char>(node->getText())
1429+
.Case("@noDerivative", 'w')
1430+
.Default(0);
1431+
assert(diffChar && "Invalid impl differentiability");
1432+
Buffer << diffChar;
1433+
}
1434+
14231435
void Remangler::mangleImplFunctionAttribute(Node *node) {
14241436
unreachable("handled inline");
14251437
}
@@ -1443,7 +1455,9 @@ void Remangler::mangleImplFunctionType(Node *node) {
14431455
case Node::Kind::ImplResult:
14441456
case Node::Kind::ImplYield:
14451457
case Node::Kind::ImplErrorResult:
1446-
mangleChildNode(Child, 1);
1458+
// Mangle type. Type should be the last child.
1459+
assert(Child->getNumChildren() == 2 || Child->getNumChildren() == 3);
1460+
mangle(Child->getLastChild());
14471461
break;
14481462
case Node::Kind::DependentPseudogenericSignature:
14491463
PseudoGeneric = "P";
@@ -1526,6 +1540,7 @@ void Remangler::mangleImplFunctionType(Node *node) {
15261540
Buffer << 'Y';
15271541
LLVM_FALLTHROUGH;
15281542
case Node::Kind::ImplParameter: {
1543+
// Mangle parameter convention.
15291544
char ConvCh =
15301545
llvm::StringSwitch<char>(Child->getFirstChild()->getText())
15311546
.Case("@in", 'i')
@@ -1540,6 +1555,9 @@ void Remangler::mangleImplFunctionType(Node *node) {
15401555
.Default(0);
15411556
assert(ConvCh && "invalid impl parameter convention");
15421557
Buffer << ConvCh;
1558+
// Mangle parameter differentiability, if it exists.
1559+
if (Child->getNumChildren() == 3)
1560+
mangleImplDifferentiability(Child->getChild(1));
15431561
break;
15441562
}
15451563
case Node::Kind::ImplErrorResult:
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: %target-build-swift -g %s
2+
3+
// SR-12650: IRGenDebugInfo type reconstruction crash because `@noDerivative`
4+
// parameters are not mangled.
5+
6+
import _Differentiation
7+
func id(_ x: Float, _ y: Float) -> Float { x }
8+
let transformed: @differentiable (Float, @noDerivative Float) -> Float = id
9+
10+
// Incorrect reconstructed type for $sS3fIedgyyd_D
11+
// Original type:
12+
// (sil_function_type type=@differentiable @callee_guaranteed (Float, @noDerivative Float) -> Float
13+
// (input=struct_type decl=Swift.(file).Float)
14+
// (input=struct_type decl=Swift.(file).Float)
15+
// (result=struct_type decl=Swift.(file).Float)
16+
// (substitution_map generic_signature=<nullptr>)
17+
// (substitution_map generic_signature=<nullptr>))
18+
// Reconstructed type:
19+
// (sil_function_type type=@differentiable @callee_guaranteed (Float, Float) -> Float
20+
// (input=struct_type decl=Swift.(file).Float)
21+
// (input=struct_type decl=Swift.(file).Float)
22+
// (result=struct_type decl=Swift.(file).Float)
23+
// (substitution_map generic_signature=<nullptr>)
24+
// (substitution_map generic_signature=<nullptr>))
25+
// Stack dump:
26+
// ...
27+
// 1. Swift version 5.3-dev (LLVM 803d1b184d, Swift 477af9f90d)
28+
// 2. While evaluating request IRGenSourceFileRequest(IR Generation for file "noderiv.swift")
29+
// 0 swift 0x00000001104c7ae8 llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 40
30+
// 1 swift 0x00000001104c6a68 llvm::sys::RunSignalHandlers() + 248
31+
// 2 swift 0x00000001104c80dd SignalHandler(int) + 285
32+
// 3 libsystem_platform.dylib 0x00007fff718335fd _sigtramp + 29
33+
// 4 libsystem_platform.dylib 000000000000000000 _sigtramp + 18446603338611739168
34+
// 5 libsystem_c.dylib 0x00007fff71709808 abort + 120
35+
// 6 swift 0x0000000110604152 (anonymous namespace)::IRGenDebugInfoImpl::getOrCreateType(swift::irgen::DebugTypeInfo) (.cold.20) + 146
36+
// 7 swift 0x000000010c24ab1e (anonymous namespace)::IRGenDebugInfoImpl::getOrCreateType(swift::irgen::DebugTypeInfo) + 3614
37+
// 8 swift 0x000000010c245437 swift::irgen::IRGenDebugInfo::emitGlobalVariableDeclaration(llvm::GlobalVariable*, llvm::StringRef, llvm::StringRef, swift::irgen::DebugTypeInfo, bool, bool, llvm::Optional<swift::SILLocation>) + 167

test/Demangle/Inputs/manglings.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,4 @@ $s17property_wrappers10WithTuplesV9fractionsSd_S2dtvpfP --> property wrapper bac
357357
$sSo17OS_dispatch_queueC4sync7executeyyyXE_tFTOTA ---> {T:$sSo17OS_dispatch_queueC4sync7executeyyyXE_tFTO} partial apply forwarder for @nonobjc __C.OS_dispatch_queue.sync(execute: () -> ()) -> ()
358358
$sxq_Idgnr_D ---> @differentiable @callee_guaranteed (@in_guaranteed A) -> (@out B)
359359
$sxq_Ilgnr_D ---> @differentiable(linear) @callee_guaranteed (@in_guaranteed A) -> (@out B)
360+
$sS3fIedgyywd_D ---> @escaping @differentiable @callee_guaranteed (@unowned Swift.Float, @unowned @noDerivative Swift.Float) -> (@unowned Swift.Float)

0 commit comments

Comments
 (0)