Skip to content

Commit 5e5f484

Browse files
authored
Merge pull request #35259 from rxwei/autodiff-mangling
[AutoDiff] Mangle derivative functions and linear maps
2 parents 8fd4970 + ffe6064 commit 5e5f484

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+696
-247
lines changed

docs/ABI/Mangling.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ types where the metadata itself has unknown layout.)
229229
global ::= entity generic-signature? type type* 'Tk' // key path setter
230230
global ::= type generic-signature 'TH' // key path equality
231231
global ::= type generic-signature 'Th' // key path hasher
232+
global ::= global generic-signature? 'TJ' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // autodiff function
232233

233234
global ::= protocol 'TL' // protocol requirements base descriptor
234235
global ::= assoc-type-name 'Tl' // associated type descriptor
@@ -271,6 +272,16 @@ are always non-polymorphic ``<impl-function-type>`` types.
271272
``<VALUE-WITNESS-KIND>`` differentiates the kinds of value
272273
witness functions for a type.
273274

275+
::
276+
277+
AUTODIFF-FUNCTION-KIND ::= 'f' // JVP (forward-mode derivative)
278+
AUTODIFF-FUNCTION-KIND ::= 'r' // VJP (reverse-mode derivative)
279+
AUTODIFF-FUNCTION-KIND ::= 'd' // differential
280+
AUTODIFF-FUNCTION-KIND ::= 'p' // pullback
281+
282+
``<AUTODIFF-FUNCTION-KIND>`` differentiates the kinds of functions assocaited
283+
with a differentiable function used for differentiable programming.
284+
274285
::
275286

276287
global ::= generic-signature? type 'WOy' // Outlined copy
@@ -1004,6 +1015,13 @@ Numbers and Indexes
10041015
``<INDEX>`` is a production for encoding numbers in contexts that can't
10051016
end in a digit; it's optimized for encoding smaller numbers.
10061017

1018+
::
1019+
1020+
INDEX-SUBSET ::= ('S' | 'U')+
1021+
1022+
``<INDEX-SUBSET>`` is encoded like a bit vector and is optimized for encoding
1023+
indices with a small upper bound.
1024+
10071025
Function Specializations
10081026
~~~~~~~~~~~~~~~~~~~~~~~~
10091027

include/swift/AST/ASTMangler.h

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,23 +166,23 @@ class ASTMangler : public Mangler {
166166
bool predefined);
167167

168168
/// Mangle the derivative function (JVP/VJP) for the given:
169-
/// - Mangled original function name.
169+
/// - Mangled original function declaration.
170170
/// - Derivative function kind.
171171
/// - Derivative function configuration: parameter/result indices and
172172
/// derivative generic signature.
173173
std::string
174-
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
175-
AutoDiffDerivativeFunctionKind kind,
176-
AutoDiffConfig config);
174+
mangleAutoDiffDerivativeFunction(const AbstractFunctionDecl *originalAFD,
175+
AutoDiffDerivativeFunctionKind kind,
176+
AutoDiffConfig config);
177177

178178
/// Mangle the linear map (differential/pullback) for the given:
179-
/// - Mangled original function name.
179+
/// - Mangled original function declaration.
180180
/// - Linear map kind.
181181
/// - Derivative function configuration: parameter/result indices and
182182
/// derivative generic signature.
183-
std::string mangleAutoDiffLinearMapHelper(StringRef name,
184-
AutoDiffLinearMapKind kind,
185-
AutoDiffConfig config);
183+
std::string mangleAutoDiffLinearMap(const AbstractFunctionDecl *originalAFD,
184+
AutoDiffLinearMapKind kind,
185+
AutoDiffConfig config);
186186

187187
/// Mangle the AutoDiff generated declaration for the given:
188188
/// - Generated declaration kind: linear map struct or branching trace enum.
@@ -255,6 +255,8 @@ class ASTMangler : public Mangler {
255255

256256
std::string mangleOpaqueTypeDecl(const ValueDecl *decl);
257257

258+
std::string mangleGenericSignature(const GenericSignature sig);
259+
258260
enum SpecialContext {
259261
ObjCContext,
260262
ClangImporterContext,
@@ -427,6 +429,12 @@ class ASTMangler : public Mangler {
427429
void appendSymbolicReference(SymbolicReferent referent);
428430

429431
void appendOpaqueDeclName(const OpaqueTypeDecl *opaqueDecl);
432+
433+
void beginManglingWithAutoDiffOriginalFunction(
434+
const AbstractFunctionDecl *afd);
435+
void appendAutoDiffFunctionParts(char functionKindCode,
436+
AutoDiffConfig config);
437+
void appendIndexSubset(IndexSubset *indexSubset);
430438
};
431439

432440
} // end namespace Mangle

include/swift/AST/AutoDiff.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/TypeAlignments.h"
2727
#include "swift/Basic/Range.h"
2828
#include "swift/Basic/SourceLoc.h"
29+
#include "swift/Demangling/Demangle.h"
2930
#include "llvm/ADT/StringExtras.h"
3031
#include "llvm/Support/Error.h"
3132

@@ -655,6 +656,17 @@ getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,
655656

656657
} // end namespace swift
657658

659+
namespace swift {
660+
namespace Demangle {
661+
662+
AutoDiffFunctionKind
663+
getAutoDiffFunctionKind(AutoDiffDerivativeFunctionKind kind);
664+
665+
AutoDiffFunctionKind getAutoDiffFunctionKind(AutoDiffLinearMapKind kind);
666+
667+
} // end namespace autodiff
668+
} // end namespace swift
669+
658670
namespace llvm {
659671

660672
using swift::AutoDiffConfig;

include/swift/Demangling/Demangle.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ enum class FunctionSigSpecializationParamKind : unsigned {
117117
ExistentialToGeneric = 1 << 10,
118118
};
119119

120+
enum class AutoDiffFunctionKind : char {
121+
JVP = 'f',
122+
VJP = 'r',
123+
Differential = 'd',
124+
Pullback = 'p',
125+
};
126+
120127
/// The pass that caused the specialization to occur. We use this to make sure
121128
/// that two passes that generate similar changes do not yield the same
122129
/// mangling. This currently cannot happen, so this is just a safety measure

include/swift/Demangling/DemangleNodes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,9 @@ NODE(CanonicalPrespecializedGenericTypeCachingOnceToken)
308308

309309
// Added in Swift 5.5
310310
NODE(AsyncFunctionPointer)
311+
NODE(AutoDiffFunction)
312+
NODE(AutoDiffFunctionKind)
313+
NODE(IndexSubset)
311314

312315
#undef CONTEXT_NODE
313316
#undef NODE

include/swift/Demangling/Demangler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,8 @@ class Demangler : public NodeFactory {
569569

570570
NodePointer demangleTypeMangling();
571571
NodePointer demangleSymbolicReference(unsigned char rawKind);
572+
NodePointer demangleAutoDiffFunctionKind();
573+
NodePointer demangleIndexSubset();
572574

573575
bool demangleBoundGenerics(Vector<NodePointer> &TypeListList,
574576
NodePointer &RetroactiveConformances);

include/swift/SILOptimizer/Differentiation/JVPCloner.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ class JVPCloner final {
3939
///
4040
/// The parent JVP cloner stores the original function and an empty
4141
/// to-be-generated pullback function.
42-
explicit JVPCloner(ADContext &context, SILFunction *original,
43-
SILDifferentiabilityWitness *witness, SILFunction *jvp,
44-
DifferentiationInvoker invoker);
42+
explicit JVPCloner(ADContext &context, SILDifferentiabilityWitness *witness,
43+
SILFunction *jvp, DifferentiationInvoker invoker);
4544
~JVPCloner();
4645

4746
/// Performs JVP generation on the empty JVP function. Returns true if any

include/swift/SILOptimizer/Differentiation/VJPCloner.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ class VJPCloner final {
3939
///
4040
/// The parent VJP cloner stores the original function and an empty
4141
/// to-be-generated pullback function.
42-
explicit VJPCloner(ADContext &context, SILFunction *original,
43-
SILDifferentiabilityWitness *witness, SILFunction *vjp,
44-
DifferentiationInvoker invoker);
42+
explicit VJPCloner(ADContext &context, SILDifferentiabilityWitness *witness,
43+
SILFunction *vjp, DifferentiationInvoker invoker);
4544
~VJPCloner();
4645

4746
ADContext &getContext() const;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===------- DifferentiationMangler.h --------- differentiation -*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef SWIFT_SIL_UTILS_DIFFERENTIATIONMANGLER_H
14+
#define SWIFT_SIL_UTILS_DIFFERENTIATIONMANGLER_H
15+
16+
#include "swift/AST/ASTMangler.h"
17+
#include "swift/AST/AutoDiff.h"
18+
#include "swift/Basic/NullablePtr.h"
19+
#include "swift/Demangling/Demangler.h"
20+
#include "swift/SIL/SILFunction.h"
21+
22+
namespace swift {
23+
namespace Mangle {
24+
25+
/// A mangler for generated differentiation functions.
26+
class DifferentiationMangler : public ASTMangler {
27+
public:
28+
DifferentiationMangler() {}
29+
/// Returns the mangled name for a differentiation function of the given kind.
30+
std::string mangle(SILFunction *originalFunction,
31+
Demangle::AutoDiffFunctionKind kind,
32+
AutoDiffConfig config);
33+
/// Returns the mangled name for a derivative function of the given kind.
34+
std::string mangleDerivativeFunction(SILFunction *originalFunction,
35+
AutoDiffDerivativeFunctionKind kind,
36+
AutoDiffConfig config);
37+
/// Returns the mangled name for a linear map of the given kind.
38+
std::string mangleLinearMap(SILFunction *originalFunction,
39+
AutoDiffLinearMapKind kind,
40+
AutoDiffConfig config);
41+
};
42+
43+
} // end namespace Mangle
44+
} // end namespace swift
45+
46+
#endif /* SWIFT_SIL_UTILS_DIFFERENTIATIONMANGLER_H */

lib/AST/ASTMangler.cpp

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "swift/AST/ASTMangler.h"
1818
#include "swift/AST/ASTContext.h"
1919
#include "swift/AST/ASTVisitor.h"
20+
#include "swift/AST/AutoDiff.h"
2021
#include "swift/AST/ExistentialLayout.h"
2122
#include "swift/AST/FileUnit.h"
2223
#include "swift/AST/GenericSignature.h"
@@ -405,55 +406,51 @@ std::string ASTMangler::mangleObjCAsyncCompletionHandlerImpl(
405406
return finalize();
406407
}
407408

408-
std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
409-
StringRef name, AutoDiffDerivativeFunctionKind kind,
409+
std::string ASTMangler::mangleAutoDiffDerivativeFunction(
410+
const AbstractFunctionDecl *originalAFD,
411+
AutoDiffDerivativeFunctionKind kind,
410412
AutoDiffConfig config) {
411-
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
412-
beginManglingWithoutPrefix();
413-
414-
Buffer << "AD__" << name << '_';
415-
switch (kind) {
416-
case AutoDiffDerivativeFunctionKind::JVP:
417-
Buffer << "_jvp_";
418-
break;
419-
case AutoDiffDerivativeFunctionKind::VJP:
420-
Buffer << "_vjp_";
421-
break;
422-
}
423-
Buffer << config.mangle();
424-
if (config.derivativeGenericSignature) {
425-
Buffer << '_';
426-
appendGenericSignature(config.derivativeGenericSignature);
427-
}
428-
429-
auto result = Storage.str().str();
430-
Storage.clear();
431-
return result;
413+
beginManglingWithAutoDiffOriginalFunction(originalAFD);
414+
appendAutoDiffFunctionParts((char)getAutoDiffFunctionKind(kind), config);
415+
return finalize();
432416
}
433417

434-
std::string ASTMangler::mangleAutoDiffLinearMapHelper(
435-
StringRef name, AutoDiffLinearMapKind kind, AutoDiffConfig config) {
436-
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
437-
beginManglingWithoutPrefix();
418+
std::string ASTMangler::mangleAutoDiffLinearMap(
419+
const AbstractFunctionDecl *originalAFD, AutoDiffLinearMapKind kind,
420+
AutoDiffConfig config) {
421+
beginManglingWithAutoDiffOriginalFunction(originalAFD);
422+
appendAutoDiffFunctionParts((char)getAutoDiffFunctionKind(kind), config);
423+
return finalize();
424+
}
438425

439-
Buffer << "AD__" << name << '_';
440-
switch (kind) {
441-
case AutoDiffLinearMapKind::Differential:
442-
Buffer << "_differential_";
443-
break;
444-
case AutoDiffLinearMapKind::Pullback:
445-
Buffer << "_pullback_";
446-
break;
447-
}
448-
Buffer << config.mangle();
449-
if (config.derivativeGenericSignature) {
450-
Buffer << '_';
451-
appendGenericSignature(config.derivativeGenericSignature);
426+
void ASTMangler::beginManglingWithAutoDiffOriginalFunction(
427+
const AbstractFunctionDecl *afd) {
428+
if (auto *attr = afd->getAttrs().getAttribute<SILGenNameAttr>()) {
429+
beginManglingWithoutPrefix();
430+
appendOperator(attr->Name);
431+
return;
452432
}
433+
beginMangling();
434+
if (auto *cd = dyn_cast<ConstructorDecl>(afd))
435+
appendConstructorEntity(cd, /*isAllocating*/ !cd->isConvenienceInit());
436+
else
437+
appendEntity(afd);
438+
}
453439

454-
auto result = Storage.str().str();
455-
Storage.clear();
456-
return result;
440+
void ASTMangler::appendAutoDiffFunctionParts(char functionKindCode,
441+
AutoDiffConfig config) {
442+
if (auto sig = config.derivativeGenericSignature)
443+
appendGenericSignature(sig);
444+
appendOperator("TJ", StringRef(&functionKindCode, 1));
445+
appendIndexSubset(config.parameterIndices);
446+
appendOperator("p");
447+
appendIndexSubset(config.resultIndices);
448+
appendOperator("r");
449+
}
450+
451+
/// Mangle the index subset.
452+
void ASTMangler::appendIndexSubset(IndexSubset *indices) {
453+
Buffer << indices->getString();
457454
}
458455

459456
std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
@@ -741,6 +738,12 @@ std::string ASTMangler::mangleOpaqueTypeDecl(const ValueDecl *decl) {
741738
return mangleDeclAsUSR(decl, MANGLING_PREFIX_STR);
742739
}
743740

741+
std::string ASTMangler::mangleGenericSignature(const GenericSignature sig) {
742+
beginMangling();
743+
appendGenericSignature(sig);
744+
return finalize();
745+
}
746+
744747
void ASTMangler::appendSymbolKind(SymbolKind SKind) {
745748
switch (SKind) {
746749
case SymbolKind::Default: return;

lib/AST/AutoDiff.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,23 @@ TangentPropertyInfo TangentStoredPropertyRequest::evaluate(
558558
// Otherwise, tangent property is valid.
559559
return TangentPropertyInfo(tanField);
560560
}
561+
562+
Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind(
563+
AutoDiffDerivativeFunctionKind kind) {
564+
switch (kind) {
565+
case AutoDiffDerivativeFunctionKind::JVP:
566+
return Demangle::AutoDiffFunctionKind::JVP;
567+
case AutoDiffDerivativeFunctionKind::VJP:
568+
return Demangle::AutoDiffFunctionKind::VJP;
569+
}
570+
}
571+
572+
Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind(
573+
AutoDiffLinearMapKind kind) {
574+
switch (kind) {
575+
case AutoDiffLinearMapKind::Differential:
576+
return Demangle::AutoDiffFunctionKind::Differential;
577+
case AutoDiffLinearMapKind::Pullback:
578+
return Demangle::AutoDiffFunctionKind::Pullback;
579+
}
580+
}

0 commit comments

Comments
 (0)