Skip to content

Commit 19f227f

Browse files
authored
Revert "[AutoDiff] Robust mangling support for AD associated functions. (#26624)" (#26656)
This reverts #26624, which introduced TF-758: demangling crash regarding JVP/VJP mangling + generic specialization mangling. Add test to prevent regression.
1 parent c010f74 commit 19f227f

17 files changed

+111
-341
lines changed

include/swift/Demangling/DemangleNodes.def

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,6 @@ NODE(DefaultAssociatedTypeMetadataAccessor)
3535
NODE(AssociatedTypeWitnessTableAccessor)
3636
NODE(BaseWitnessTableAccessor)
3737
NODE(AutoClosureType)
38-
// SWIFT_ENABLE_TENSORFLOW
39-
NODE(AutoDiffParameterIndices)
40-
NODE(AutoDiffResultIndex)
41-
NODE(AutoDiffJVP)
42-
NODE(AutoDiffVJP)
43-
NODE(AutoDiffDifferential)
44-
NODE(AutoDiffPullback)
45-
// SWIFT_ENABLE_TENSORFLOW END
4638
NODE(BoundGenericClass)
4739
NODE(BoundGenericEnum)
4840
NODE(BoundGenericStructure)

lib/AST/ASTMangler.cpp

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -369,86 +369,51 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
369369
return finalize();
370370
}
371371

372-
// SWIFT_ENABLE_TENSORFLOW
373-
/// Get a `NodePointer` representing an autodiff associated function for the
374-
/// given original function mangled name, associated function node kind,
375-
/// and parameter indices.
376-
static NodePointer getMangledAutoDiffAssociatedFunctionNode(
377-
Demangler &D, StringRef name, Node::Kind assocFnNodeKind,
378-
const SILAutoDiffIndices &indices) {
379-
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
380-
auto topLevel = D.createNode(Node::Kind::Global);
381-
auto assocFn = D.createNode(assocFnNodeKind);
382-
topLevel->addChild(assocFn, D);
383-
384-
auto funcTopLevel = D.demangleSymbol(name);
385-
// If original function name cannot be demangled (e.g. it has a custom name
386-
// via `@_silgen_name`), add it as an identifier node.
387-
if (!funcTopLevel) {
388-
funcTopLevel = D.createNode(Node::Kind::Global);
389-
funcTopLevel->addChild(D.createNode(Node::Kind::Identifier, name), D);
390-
}
391-
assert(funcTopLevel);
392-
for (auto funcChild : *funcTopLevel)
393-
assocFn->addChild(funcChild, D);
394-
395-
auto paramIndices =
396-
D.createNode(Node::Kind::AutoDiffParameterIndices);
397-
for (unsigned i : indices.parameters->getIndices()) {
398-
auto paramIdx = D.createNode(Node::Kind::Index, i);
399-
paramIndices->addChild(paramIdx, D);
400-
}
401-
assocFn->addChild(paramIndices, D);
402-
auto resultIdx =
403-
D.createNode(Node::Kind::AutoDiffResultIndex, indices.source);
404-
assocFn->addChild(resultIdx, D);
405-
return topLevel;
406-
}
407-
408372
std::string ASTMangler::mangleAutoDiffAssociatedFunctionHelper(
409373
StringRef name, AutoDiffAssociatedFunctionKind kind,
410374
const SILAutoDiffIndices &indices) {
375+
// TODO(TF-20): Make the mangling scheme robust.
376+
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
411377
beginManglingWithoutPrefix();
412378

413-
Demangler D;
414-
Node::Kind assocFnNodeKind;
379+
Buffer << "AD__" << name << '_';
415380
switch (kind) {
416381
case AutoDiffAssociatedFunctionKind::JVP:
417-
assocFnNodeKind = Node::Kind::AutoDiffJVP;
382+
Buffer << "_jvp_";
418383
break;
419384
case AutoDiffAssociatedFunctionKind::VJP:
420-
assocFnNodeKind = Node::Kind::AutoDiffVJP;
385+
Buffer << "_vjp_";
421386
break;
422387
}
423-
auto result = getMangledAutoDiffAssociatedFunctionNode(
424-
D, name, assocFnNodeKind, indices);
425-
auto mangled = Demangle::mangleNode(result);
426-
verify(mangled);
427-
return mangled;
388+
Buffer << indices.mangle();
389+
390+
auto result = Storage.str().str();
391+
Storage.clear();
392+
return result;
428393
}
429394

430395
std::string ASTMangler::mangleAutoDiffLinearMapHelper(
431396
StringRef name, AutoDiffLinearMapKind kind,
432397
const SILAutoDiffIndices &indices) {
398+
// TODO(TF-20): Make the mangling scheme robust.
399+
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
433400
beginManglingWithoutPrefix();
434401

435-
Demangler D;
436-
Node::Kind assocFnNodeKind;
402+
Buffer << "AD__" << name << '_';
437403
switch (kind) {
438404
case AutoDiffLinearMapKind::Differential:
439-
assocFnNodeKind = Node::Kind::AutoDiffDifferential;
405+
Buffer << "_differential_";
440406
break;
441407
case AutoDiffLinearMapKind::Pullback:
442-
assocFnNodeKind = Node::Kind::AutoDiffPullback;
408+
Buffer << "_pullback_";
443409
break;
444410
}
445-
auto result = getMangledAutoDiffAssociatedFunctionNode(
446-
D, name, assocFnNodeKind, indices);
447-
auto mangled = Demangle::mangleNode(result);
448-
verify(mangled);
449-
return mangled;
411+
Buffer << indices.mangle();
412+
413+
auto result = Storage.str().str();
414+
Storage.clear();
415+
return result;
450416
}
451-
// SWIFT_ENABLE_TENSORFLOW END
452417

453418
std::string ASTMangler::mangleTypeForDebugger(Type Ty, const DeclContext *DC) {
454419
PrettyStackTraceType prettyStackTrace(Ty->getASTContext(),

lib/Demangling/Demangler.cpp

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,39 +2138,6 @@ NodePointer Demangler::demangleThunkOrSpecialization() {
21382138
addChild(Thunk, popNode(Node::Kind::Type));
21392139
return Thunk;
21402140
}
2141-
// SWIFT_ENABLE_TENSORFLOW
2142-
case 'z':
2143-
case 'Z':
2144-
case 'u':
2145-
case 'U': {
2146-
// Create node for autodiff associated function.
2147-
Node::Kind assocFnKind;
2148-
if (c == 'z') assocFnKind = Node::Kind::AutoDiffJVP;
2149-
else if (c == 'Z') assocFnKind = Node::Kind::AutoDiffVJP;
2150-
else if (c == 'u') assocFnKind = Node::Kind::AutoDiffDifferential;
2151-
else if (c == 'U') assocFnKind = Node::Kind::AutoDiffPullback;
2152-
else return nullptr;
2153-
NodePointer assocFn = createNode(assocFnKind);
2154-
addChild(assocFn, popNode());
2155-
// Demangle parameter indices.
2156-
auto paramIndices = createNode(Node::Kind::AutoDiffParameterIndices);
2157-
nextIf('p');
2158-
while (true) {
2159-
auto index = demangleNatural();
2160-
if (index < 0)
2161-
break;
2162-
paramIndices->addChild(createNode(Node::Kind::Index, index), *this);
2163-
nextIf('_');
2164-
}
2165-
// Demangle result index.
2166-
nextIf('r');
2167-
addChild(assocFn, paramIndices);
2168-
int resultIdx = demangleNatural();
2169-
auto resultIndex = createNode(Node::Kind::AutoDiffResultIndex, resultIdx);
2170-
addChild(assocFn, resultIndex);
2171-
return assocFn;
2172-
}
2173-
// SWIFT_ENABLE_TENSORFLOW END
21742141
case 'g':
21752142
return demangleGenericSpecialization(Node::Kind::GenericSpecialization);
21762143
case 'G':
@@ -2750,13 +2717,15 @@ NodePointer Demangler::demangleSpecialType() {
27502717
// SWIFT_ENABLE_TENSORFLOW
27512718
case 'F':
27522719
return popFunctionType(Node::Kind::DifferentiableFunctionType);
2720+
// SWIFT_ENABLE_TENSORFLOW
27532721
case 'G':
27542722
return popFunctionType(Node::Kind::EscapingDifferentiableFunctionType);
2723+
// SWIFT_ENABLE_TENSORFLOW
27552724
case 'H':
27562725
return popFunctionType(Node::Kind::LinearFunctionType);
2726+
// SWIFT_ENABLE_TENSORFLOW
27572727
case 'I':
27582728
return popFunctionType(Node::Kind::EscapingLinearFunctionType);
2759-
// SWIFT_ENABLE_TENSORFLOW END
27602729
case 'o':
27612730
return createType(createWithChild(Node::Kind::Unowned,
27622731
popNode(Node::Kind::Type)));

lib/Demangling/NodePrinter.cpp

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,6 @@ class NodePrinter {
330330
case Node::Kind::AssociatedTypeMetadataAccessor:
331331
case Node::Kind::AssociatedTypeWitnessTableAccessor:
332332
case Node::Kind::AutoClosureType:
333-
// SWIFT_ENABLE_TENSORFLOW
334-
case Node::Kind::AutoDiffParameterIndices:
335-
case Node::Kind::AutoDiffResultIndex:
336-
case Node::Kind::AutoDiffJVP:
337-
case Node::Kind::AutoDiffVJP:
338-
case Node::Kind::AutoDiffDifferential:
339-
case Node::Kind::AutoDiffPullback:
340-
// SWIFT_ENABLE_TENSORFLOW END
341333
case Node::Kind::BaseConformanceDescriptor:
342334
case Node::Kind::BaseWitnessTableAccessor:
343335
case Node::Kind::ClassMetadataBaseOffset:
@@ -673,25 +665,6 @@ class NodePrinter {
673665
}
674666
}
675667

676-
// SWIFT_ENABLE_TENSORFLOW
677-
void printAutoDiffAssociatedFunction(NodePointer Node) {
678-
if (Node->getKind() == Node::Kind::AutoDiffJVP)
679-
Printer << "JVP ";
680-
else if (Node->getKind() == Node::Kind::AutoDiffVJP)
681-
Printer << "VJP ";
682-
else if (Node->getKind() == Node::Kind::AutoDiffDifferential)
683-
Printer << "differential ";
684-
else if (Node->getKind() == Node::Kind::AutoDiffPullback)
685-
Printer << "pullback ";
686-
else
687-
assert(false && "Unknown autodiff associated function kind");
688-
print(Node->getChild(1)); // wrt param indices
689-
print(Node->getChild(2)); // result index
690-
Printer << "for ";
691-
print(Node->getChild(0)); // original function
692-
}
693-
// SWIFT_ENABLE_TENSORFLOW END
694-
695668
NodePointer getChildIf(NodePointer Node, Node::Kind Kind) {
696669
auto result =
697670
std::find_if(Node->begin(), Node->end(), [&](NodePointer child) {
@@ -1653,28 +1626,6 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
16531626
print(Node->getChild(idx));
16541627
return nullptr;
16551628
}
1656-
// SWIFT_ENABLE_TENSORFLOW
1657-
case Node::Kind::AutoDiffParameterIndices: {
1658-
Printer << "wrt ";
1659-
interleave(Node->begin(), Node->end(),
1660-
[&](NodePointer child) {
1661-
Printer << child->getIndex();
1662-
}, [&]() { Printer << ", "; });
1663-
Printer << ' ';
1664-
return nullptr;
1665-
}
1666-
case Node::Kind::AutoDiffResultIndex: {
1667-
Printer << "source " << Node->getIndex() << ' ';
1668-
return nullptr;
1669-
}
1670-
case Node::Kind::AutoDiffJVP:
1671-
case Node::Kind::AutoDiffVJP:
1672-
case Node::Kind::AutoDiffDifferential:
1673-
case Node::Kind::AutoDiffPullback: {
1674-
printAutoDiffAssociatedFunction(Node);
1675-
return nullptr;
1676-
}
1677-
// SWIFT_ENABLE_TENSORFLOW END
16781629
case Node::Kind::MergedFunction:
16791630
if (!Options.ShortenThunk) {
16801631
Printer << "merged ";

lib/Demangling/OldRemangler.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -730,32 +730,6 @@ void Remangler::mangleReabstractionThunk(Node *node) {
730730
Buffer << "<reabstraction-thunk>";
731731
}
732732

733-
// SWIFT_ENABLE_TENSORFLOW
734-
void Remangler::mangleAutoDiffParameterIndices(Node *node) {
735-
Buffer << "<autodiff-parameter-indices>";
736-
}
737-
738-
void Remangler::mangleAutoDiffResultIndex(Node *node) {
739-
Buffer << "<autodiff-result-index>";
740-
}
741-
742-
void Remangler::mangleAutoDiffJVP(Node *node) {
743-
Buffer << "<autodiff-jvp-function>";
744-
}
745-
746-
void Remangler::mangleAutoDiffVJP(Node *node) {
747-
Buffer << "<autodiff-vjp-function>";
748-
}
749-
750-
void Remangler::mangleAutoDiffDifferential(Node *node) {
751-
Buffer << "<autodiff-differential-function>";
752-
}
753-
754-
void Remangler::mangleAutoDiffPullback(Node *node) {
755-
Buffer << "<autodiff-pullback-function>";
756-
}
757-
// SWIFT_ENABLE_TENSORFLOW END
758-
759733
void Remangler::mangleProtocolSelfConformanceWitness(Node *node) {
760734
Buffer << "TS";
761735
mangleSingleChildNode(node); // entity

lib/Demangling/Remangler.cpp

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,6 @@ class Remangler : public RemanglerBase {
309309
void mangleAnyProtocolConformance(Node *node);
310310

311311
void mangleKeyPathThunkHelper(Node *node, StringRef op);
312-
// SWIFT_ENABLE_TENSORFLOW
313-
void mangleAutoDiffAssociatedFunctionHelper(Node *node, StringRef op);
314-
// SWIFT_ENABLE_TENSORFLOW END
315312

316313
#define NODE(ID) \
317314
void mangle##ID(Node *node);
@@ -1955,48 +1952,6 @@ void Remangler::mangleReabstractionThunkHelperWithSelf(Node *node) {
19551952
Buffer << "Ty";
19561953
}
19571954

1958-
// SWIFT_ENABLE_TENSORFLOW
1959-
void Remangler::mangleAutoDiffParameterIndices(Node *node) {
1960-
Buffer << 'p';
1961-
for (unsigned i = 0, n = node->getNumChildren(); i != n; ++i) {
1962-
auto child = node->getChild(i);
1963-
Buffer << child->getIndex();
1964-
if (i != n - 1)
1965-
Buffer << '_';
1966-
}
1967-
}
1968-
1969-
void Remangler::mangleAutoDiffResultIndex(Node *node) {
1970-
Buffer << 'r' << node->getIndex();
1971-
}
1972-
1973-
void
1974-
Remangler::mangleAutoDiffAssociatedFunctionHelper(Node *node, StringRef op) {
1975-
// TODO(TF-680): Mangle `[differentiable]` atttribute requirements as well.
1976-
assert(node->getNumChildren() == 3);
1977-
mangleChildNode(node, 0); // original function
1978-
Buffer << op;
1979-
mangleChildNode(node, 1); // wrt parameter indices
1980-
mangleChildNode(node, 2); // result index
1981-
}
1982-
1983-
void Remangler::mangleAutoDiffJVP(Node *node) {
1984-
mangleAutoDiffAssociatedFunctionHelper(node, "Tz");
1985-
}
1986-
1987-
void Remangler::mangleAutoDiffVJP(Node *node) {
1988-
mangleAutoDiffAssociatedFunctionHelper(node, "TZ");
1989-
}
1990-
1991-
void Remangler::mangleAutoDiffDifferential(Node *node) {
1992-
mangleAutoDiffAssociatedFunctionHelper(node, "Tu");
1993-
}
1994-
1995-
void Remangler::mangleAutoDiffPullback(Node *node) {
1996-
mangleAutoDiffAssociatedFunctionHelper(node, "TU");
1997-
}
1998-
// SWIFT_ENABLE_TENSORFLOW END
1999-
20001955
void Remangler::mangleReadAccessor(Node *node) {
20011956
mangleAbstractStorage(node->getFirstChild(), "r");
20021957
}

test/AutoDiff/closures.swift

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,11 @@ struct InoutAliasableCapture {
2626

2727
// CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () {
2828
// CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture):
29-
// CHECK: // function_ref JVP wrt 0 source 0 for capturesMutableSelf #1 (t:) in InoutAliasableCapture.foo()
30-
// CHECK: [[JVP:%.*]] = function_ref @$s8closures21InoutAliasableCaptureV3fooyyF19capturesMutableSelfL_1tS2f_tFTzp0r0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
29+
// CHECK: [[JVP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__jvp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
3130
// CHECK-NOT: retain_value_addr [[SELF]]
3231
// CHECK-NOT: copy_addr [[SELF]]
3332
// CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
34-
// CHECK: // function_ref VJP wrt 0 source 0 for capturesMutableSelf #1 (t:) in InoutAliasableCapture.foo()
35-
// CHECK: [[VJP:%.*]] = function_ref @$s8closures21InoutAliasableCaptureV3fooyyF19capturesMutableSelfL_1tS2f_tFTZp0r0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
33+
// CHECK: [[VJP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
3634
// CHECK-NOT: retain_value_addr [[SELF]]
3735
// CHECK-NOT: copy_addr [[SELF]]
3836
// CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
@@ -45,11 +43,10 @@ public func closureCaptureMutable() {
4543
}
4644
}
4745

48-
// CHECK-LABEL: // VJP wrt 0 source 0 for closure #1 in closureCaptureMutable()
49-
// CHECK-NEXT: @$s8closures21closureCaptureMutableyyFS2fcfU_TZp0r0
46+
// CHECK-LABEL: @AD__{{.*}}closureCaptureMutable{{.*}}___vjp_src_0_wrt_0
5047
// CHECK: bb0({{%.*}} : $Float, [[INOUT_ARG:%.*]] : ${ var Float }):
51-
// CHECK: [[PB:%.*]] = function_ref @$s8closures21closureCaptureMutableyyFS2fcfU_TUp0r0
52-
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[PB]]({{.*}})
48+
// CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___pullback_src_0_wrt_0
49+
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}})
5350

5451
// TF-30: VJP return value should match the return type.
5552
struct TF_30 : Differentiable {

0 commit comments

Comments
 (0)