Skip to content

Commit c010f74

Browse files
authored
[AutoDiff] Robust mangling support for AD associated functions. (#26624)
Add demangler/remangler support for AD associated functions: JVPs, VJPs, differentials, pullbacks. These functions now appear with clarifying annotations in SIL: ``` // VJP wrt 0 source 0 for foo(_:) sil hidden @$s8mangling3fooyS2fFTZp0r0 ``` The following components are mangled: - Original function name. - Parameter indices. - Result index. Resolves TF-679.
1 parent 4df23c8 commit c010f74

17 files changed

+341
-95
lines changed

include/swift/Demangling/DemangleNodes.def

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ NODE(DefaultAssociatedTypeMetadataAccessor)
3535
NODE(AssociatedTypeWitnessTableAccessor)
3636
NODE(BaseWitnessTableAccessor)
3737
NODE(AutoClosureType)
38+
// SWIFT_ENABLE_TENSORFLOW
39+
NODE(AutoDiffParameterIndices)
40+
NODE(AutoDiffResultIndex)
41+
NODE(AutoDiffJVP)
42+
NODE(AutoDiffVJP)
43+
NODE(AutoDiffDifferential)
44+
NODE(AutoDiffPullback)
45+
// SWIFT_ENABLE_TENSORFLOW END
3846
NODE(BoundGenericClass)
3947
NODE(BoundGenericEnum)
4048
NODE(BoundGenericStructure)

lib/AST/ASTMangler.cpp

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -369,51 +369,86 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
369369
return finalize();
370370
}
371371

372+
// SWIFT_ENABLE_TENSORFLOW
373+
/// Get a `NodePointer` representing an autodiff associated function for the
374+
/// given original function mangled name, associated function node kind,
375+
/// and parameter indices.
376+
static NodePointer getMangledAutoDiffAssociatedFunctionNode(
377+
Demangler &D, StringRef name, Node::Kind assocFnNodeKind,
378+
const SILAutoDiffIndices &indices) {
379+
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
380+
auto topLevel = D.createNode(Node::Kind::Global);
381+
auto assocFn = D.createNode(assocFnNodeKind);
382+
topLevel->addChild(assocFn, D);
383+
384+
auto funcTopLevel = D.demangleSymbol(name);
385+
// If original function name cannot be demangled (e.g. it has a custom name
386+
// via `@_silgen_name`), add it as an identifier node.
387+
if (!funcTopLevel) {
388+
funcTopLevel = D.createNode(Node::Kind::Global);
389+
funcTopLevel->addChild(D.createNode(Node::Kind::Identifier, name), D);
390+
}
391+
assert(funcTopLevel);
392+
for (auto funcChild : *funcTopLevel)
393+
assocFn->addChild(funcChild, D);
394+
395+
auto paramIndices =
396+
D.createNode(Node::Kind::AutoDiffParameterIndices);
397+
for (unsigned i : indices.parameters->getIndices()) {
398+
auto paramIdx = D.createNode(Node::Kind::Index, i);
399+
paramIndices->addChild(paramIdx, D);
400+
}
401+
assocFn->addChild(paramIndices, D);
402+
auto resultIdx =
403+
D.createNode(Node::Kind::AutoDiffResultIndex, indices.source);
404+
assocFn->addChild(resultIdx, D);
405+
return topLevel;
406+
}
407+
372408
std::string ASTMangler::mangleAutoDiffAssociatedFunctionHelper(
373409
StringRef name, AutoDiffAssociatedFunctionKind kind,
374410
const SILAutoDiffIndices &indices) {
375-
// TODO(TF-20): Make the mangling scheme robust.
376-
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
377411
beginManglingWithoutPrefix();
378412

379-
Buffer << "AD__" << name << '_';
413+
Demangler D;
414+
Node::Kind assocFnNodeKind;
380415
switch (kind) {
381416
case AutoDiffAssociatedFunctionKind::JVP:
382-
Buffer << "_jvp_";
417+
assocFnNodeKind = Node::Kind::AutoDiffJVP;
383418
break;
384419
case AutoDiffAssociatedFunctionKind::VJP:
385-
Buffer << "_vjp_";
420+
assocFnNodeKind = Node::Kind::AutoDiffVJP;
386421
break;
387422
}
388-
Buffer << indices.mangle();
389-
390-
auto result = Storage.str().str();
391-
Storage.clear();
392-
return result;
423+
auto result = getMangledAutoDiffAssociatedFunctionNode(
424+
D, name, assocFnNodeKind, indices);
425+
auto mangled = Demangle::mangleNode(result);
426+
verify(mangled);
427+
return mangled;
393428
}
394429

395430
std::string ASTMangler::mangleAutoDiffLinearMapHelper(
396431
StringRef name, AutoDiffLinearMapKind kind,
397432
const SILAutoDiffIndices &indices) {
398-
// TODO(TF-20): Make the mangling scheme robust.
399-
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
400433
beginManglingWithoutPrefix();
401434

402-
Buffer << "AD__" << name << '_';
435+
Demangler D;
436+
Node::Kind assocFnNodeKind;
403437
switch (kind) {
404438
case AutoDiffLinearMapKind::Differential:
405-
Buffer << "_differential_";
439+
assocFnNodeKind = Node::Kind::AutoDiffDifferential;
406440
break;
407441
case AutoDiffLinearMapKind::Pullback:
408-
Buffer << "_pullback_";
442+
assocFnNodeKind = Node::Kind::AutoDiffPullback;
409443
break;
410444
}
411-
Buffer << indices.mangle();
412-
413-
auto result = Storage.str().str();
414-
Storage.clear();
415-
return result;
445+
auto result = getMangledAutoDiffAssociatedFunctionNode(
446+
D, name, assocFnNodeKind, indices);
447+
auto mangled = Demangle::mangleNode(result);
448+
verify(mangled);
449+
return mangled;
416450
}
451+
// SWIFT_ENABLE_TENSORFLOW END
417452

418453
std::string ASTMangler::mangleTypeForDebugger(Type Ty, const DeclContext *DC) {
419454
PrettyStackTraceType prettyStackTrace(Ty->getASTContext(),

lib/Demangling/Demangler.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,6 +2138,39 @@ NodePointer Demangler::demangleThunkOrSpecialization() {
21382138
addChild(Thunk, popNode(Node::Kind::Type));
21392139
return Thunk;
21402140
}
2141+
// SWIFT_ENABLE_TENSORFLOW
2142+
case 'z':
2143+
case 'Z':
2144+
case 'u':
2145+
case 'U': {
2146+
// Create node for autodiff associated function.
2147+
Node::Kind assocFnKind;
2148+
if (c == 'z') assocFnKind = Node::Kind::AutoDiffJVP;
2149+
else if (c == 'Z') assocFnKind = Node::Kind::AutoDiffVJP;
2150+
else if (c == 'u') assocFnKind = Node::Kind::AutoDiffDifferential;
2151+
else if (c == 'U') assocFnKind = Node::Kind::AutoDiffPullback;
2152+
else return nullptr;
2153+
NodePointer assocFn = createNode(assocFnKind);
2154+
addChild(assocFn, popNode());
2155+
// Demangle parameter indices.
2156+
auto paramIndices = createNode(Node::Kind::AutoDiffParameterIndices);
2157+
nextIf('p');
2158+
while (true) {
2159+
auto index = demangleNatural();
2160+
if (index < 0)
2161+
break;
2162+
paramIndices->addChild(createNode(Node::Kind::Index, index), *this);
2163+
nextIf('_');
2164+
}
2165+
// Demangle result index.
2166+
nextIf('r');
2167+
addChild(assocFn, paramIndices);
2168+
int resultIdx = demangleNatural();
2169+
auto resultIndex = createNode(Node::Kind::AutoDiffResultIndex, resultIdx);
2170+
addChild(assocFn, resultIndex);
2171+
return assocFn;
2172+
}
2173+
// SWIFT_ENABLE_TENSORFLOW END
21412174
case 'g':
21422175
return demangleGenericSpecialization(Node::Kind::GenericSpecialization);
21432176
case 'G':
@@ -2717,15 +2750,13 @@ NodePointer Demangler::demangleSpecialType() {
27172750
// SWIFT_ENABLE_TENSORFLOW
27182751
case 'F':
27192752
return popFunctionType(Node::Kind::DifferentiableFunctionType);
2720-
// SWIFT_ENABLE_TENSORFLOW
27212753
case 'G':
27222754
return popFunctionType(Node::Kind::EscapingDifferentiableFunctionType);
2723-
// SWIFT_ENABLE_TENSORFLOW
27242755
case 'H':
27252756
return popFunctionType(Node::Kind::LinearFunctionType);
2726-
// SWIFT_ENABLE_TENSORFLOW
27272757
case 'I':
27282758
return popFunctionType(Node::Kind::EscapingLinearFunctionType);
2759+
// SWIFT_ENABLE_TENSORFLOW END
27292760
case 'o':
27302761
return createType(createWithChild(Node::Kind::Unowned,
27312762
popNode(Node::Kind::Type)));

lib/Demangling/NodePrinter.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,14 @@ class NodePrinter {
330330
case Node::Kind::AssociatedTypeMetadataAccessor:
331331
case Node::Kind::AssociatedTypeWitnessTableAccessor:
332332
case Node::Kind::AutoClosureType:
333+
// SWIFT_ENABLE_TENSORFLOW
334+
case Node::Kind::AutoDiffParameterIndices:
335+
case Node::Kind::AutoDiffResultIndex:
336+
case Node::Kind::AutoDiffJVP:
337+
case Node::Kind::AutoDiffVJP:
338+
case Node::Kind::AutoDiffDifferential:
339+
case Node::Kind::AutoDiffPullback:
340+
// SWIFT_ENABLE_TENSORFLOW END
333341
case Node::Kind::BaseConformanceDescriptor:
334342
case Node::Kind::BaseWitnessTableAccessor:
335343
case Node::Kind::ClassMetadataBaseOffset:
@@ -665,6 +673,25 @@ class NodePrinter {
665673
}
666674
}
667675

676+
// SWIFT_ENABLE_TENSORFLOW
677+
void printAutoDiffAssociatedFunction(NodePointer Node) {
678+
if (Node->getKind() == Node::Kind::AutoDiffJVP)
679+
Printer << "JVP ";
680+
else if (Node->getKind() == Node::Kind::AutoDiffVJP)
681+
Printer << "VJP ";
682+
else if (Node->getKind() == Node::Kind::AutoDiffDifferential)
683+
Printer << "differential ";
684+
else if (Node->getKind() == Node::Kind::AutoDiffPullback)
685+
Printer << "pullback ";
686+
else
687+
assert(false && "Unknown autodiff associated function kind");
688+
print(Node->getChild(1)); // wrt param indices
689+
print(Node->getChild(2)); // result index
690+
Printer << "for ";
691+
print(Node->getChild(0)); // original function
692+
}
693+
// SWIFT_ENABLE_TENSORFLOW END
694+
668695
NodePointer getChildIf(NodePointer Node, Node::Kind Kind) {
669696
auto result =
670697
std::find_if(Node->begin(), Node->end(), [&](NodePointer child) {
@@ -1626,6 +1653,28 @@ NodePointer NodePrinter::print(NodePointer Node, bool asPrefixContext) {
16261653
print(Node->getChild(idx));
16271654
return nullptr;
16281655
}
1656+
// SWIFT_ENABLE_TENSORFLOW
1657+
case Node::Kind::AutoDiffParameterIndices: {
1658+
Printer << "wrt ";
1659+
interleave(Node->begin(), Node->end(),
1660+
[&](NodePointer child) {
1661+
Printer << child->getIndex();
1662+
}, [&]() { Printer << ", "; });
1663+
Printer << ' ';
1664+
return nullptr;
1665+
}
1666+
case Node::Kind::AutoDiffResultIndex: {
1667+
Printer << "source " << Node->getIndex() << ' ';
1668+
return nullptr;
1669+
}
1670+
case Node::Kind::AutoDiffJVP:
1671+
case Node::Kind::AutoDiffVJP:
1672+
case Node::Kind::AutoDiffDifferential:
1673+
case Node::Kind::AutoDiffPullback: {
1674+
printAutoDiffAssociatedFunction(Node);
1675+
return nullptr;
1676+
}
1677+
// SWIFT_ENABLE_TENSORFLOW END
16291678
case Node::Kind::MergedFunction:
16301679
if (!Options.ShortenThunk) {
16311680
Printer << "merged ";

lib/Demangling/OldRemangler.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,32 @@ void Remangler::mangleReabstractionThunk(Node *node) {
730730
Buffer << "<reabstraction-thunk>";
731731
}
732732

733+
// SWIFT_ENABLE_TENSORFLOW
734+
void Remangler::mangleAutoDiffParameterIndices(Node *node) {
735+
Buffer << "<autodiff-parameter-indices>";
736+
}
737+
738+
void Remangler::mangleAutoDiffResultIndex(Node *node) {
739+
Buffer << "<autodiff-result-index>";
740+
}
741+
742+
void Remangler::mangleAutoDiffJVP(Node *node) {
743+
Buffer << "<autodiff-jvp-function>";
744+
}
745+
746+
void Remangler::mangleAutoDiffVJP(Node *node) {
747+
Buffer << "<autodiff-vjp-function>";
748+
}
749+
750+
void Remangler::mangleAutoDiffDifferential(Node *node) {
751+
Buffer << "<autodiff-differential-function>";
752+
}
753+
754+
void Remangler::mangleAutoDiffPullback(Node *node) {
755+
Buffer << "<autodiff-pullback-function>";
756+
}
757+
// SWIFT_ENABLE_TENSORFLOW END
758+
733759
void Remangler::mangleProtocolSelfConformanceWitness(Node *node) {
734760
Buffer << "TS";
735761
mangleSingleChildNode(node); // entity

lib/Demangling/Remangler.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ class Remangler : public RemanglerBase {
309309
void mangleAnyProtocolConformance(Node *node);
310310

311311
void mangleKeyPathThunkHelper(Node *node, StringRef op);
312+
// SWIFT_ENABLE_TENSORFLOW
313+
void mangleAutoDiffAssociatedFunctionHelper(Node *node, StringRef op);
314+
// SWIFT_ENABLE_TENSORFLOW END
312315

313316
#define NODE(ID) \
314317
void mangle##ID(Node *node);
@@ -1952,6 +1955,48 @@ void Remangler::mangleReabstractionThunkHelperWithSelf(Node *node) {
19521955
Buffer << "Ty";
19531956
}
19541957

1958+
// SWIFT_ENABLE_TENSORFLOW
1959+
void Remangler::mangleAutoDiffParameterIndices(Node *node) {
1960+
Buffer << 'p';
1961+
for (unsigned i = 0, n = node->getNumChildren(); i != n; ++i) {
1962+
auto child = node->getChild(i);
1963+
Buffer << child->getIndex();
1964+
if (i != n - 1)
1965+
Buffer << '_';
1966+
}
1967+
}
1968+
1969+
void Remangler::mangleAutoDiffResultIndex(Node *node) {
1970+
Buffer << 'r' << node->getIndex();
1971+
}
1972+
1973+
void
1974+
Remangler::mangleAutoDiffAssociatedFunctionHelper(Node *node, StringRef op) {
1975+
// TODO(TF-680): Mangle `[differentiable]` atttribute requirements as well.
1976+
assert(node->getNumChildren() == 3);
1977+
mangleChildNode(node, 0); // original function
1978+
Buffer << op;
1979+
mangleChildNode(node, 1); // wrt parameter indices
1980+
mangleChildNode(node, 2); // result index
1981+
}
1982+
1983+
void Remangler::mangleAutoDiffJVP(Node *node) {
1984+
mangleAutoDiffAssociatedFunctionHelper(node, "Tz");
1985+
}
1986+
1987+
void Remangler::mangleAutoDiffVJP(Node *node) {
1988+
mangleAutoDiffAssociatedFunctionHelper(node, "TZ");
1989+
}
1990+
1991+
void Remangler::mangleAutoDiffDifferential(Node *node) {
1992+
mangleAutoDiffAssociatedFunctionHelper(node, "Tu");
1993+
}
1994+
1995+
void Remangler::mangleAutoDiffPullback(Node *node) {
1996+
mangleAutoDiffAssociatedFunctionHelper(node, "TU");
1997+
}
1998+
// SWIFT_ENABLE_TENSORFLOW END
1999+
19552000
void Remangler::mangleReadAccessor(Node *node) {
19562001
mangleAbstractStorage(node->getFirstChild(), "r");
19572002
}

test/AutoDiff/closures.swift

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ struct InoutAliasableCapture {
2626

2727
// CHECK-LABEL: @{{.*}}InoutAliasableCapture{{.*}}foo{{.*}} : $@convention(method) (@inout InoutAliasableCapture) -> () {
2828
// CHECK: bb0([[SELF:%.*]] : $*InoutAliasableCapture):
29-
// CHECK: [[JVP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__jvp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
29+
// CHECK: // function_ref JVP wrt 0 source 0 for capturesMutableSelf #1 (t:) in InoutAliasableCapture.foo()
30+
// CHECK: [[JVP:%.*]] = function_ref @$s8closures21InoutAliasableCaptureV3fooyyF19capturesMutableSelfL_1tS2f_tFTzp0r0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
3031
// CHECK-NOT: retain_value_addr [[SELF]]
3132
// CHECK-NOT: copy_addr [[SELF]]
3233
// CHECK: [[JVP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[JVP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
33-
// CHECK: [[VJP:%.*]] = function_ref @{{.*}}capturesMutableSelf{{.*}}__vjp_src_0_wrt_0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
34+
// CHECK: // function_ref VJP wrt 0 source 0 for capturesMutableSelf #1 (t:) in InoutAliasableCapture.foo()
35+
// CHECK: [[VJP:%.*]] = function_ref @$s8closures21InoutAliasableCaptureV3fooyyF19capturesMutableSelfL_1tS2f_tFTZp0r0 : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
3436
// CHECK-NOT: retain_value_addr [[SELF]]
3537
// CHECK-NOT: copy_addr [[SELF]]
3638
// CHECK: [[VJP_CAPTURED:%.*]] = partial_apply [callee_guaranteed] [[VJP]]([[SELF]]) : $@convention(thin) (Float, @inout_aliasable InoutAliasableCapture) -> (Float, @owned @callee_guaranteed (Float) -> Float)
@@ -43,10 +45,11 @@ public func closureCaptureMutable() {
4345
}
4446
}
4547

46-
// CHECK-LABEL: @AD__{{.*}}closureCaptureMutable{{.*}}___vjp_src_0_wrt_0
48+
// CHECK-LABEL: // VJP wrt 0 source 0 for closure #1 in closureCaptureMutable()
49+
// CHECK-NEXT: @$s8closures21closureCaptureMutableyyFS2fcfU_TZp0r0
4750
// CHECK: bb0({{%.*}} : $Float, [[INOUT_ARG:%.*]] : ${ var Float }):
48-
// CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___pullback_src_0_wrt_0
49-
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}})
51+
// CHECK: [[PB:%.*]] = function_ref @$s8closures21closureCaptureMutableyyFS2fcfU_TUp0r0
52+
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[PB]]({{.*}})
5053

5154
// TF-30: VJP return value should match the return type.
5255
struct TF_30 : Differentiable {

0 commit comments

Comments
 (0)