Skip to content

[AutoDiff] Mangle derivative functions and linear maps #35259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/ABI/Mangling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ types where the metadata itself has unknown layout.)
global ::= entity generic-signature? type type* 'Tk' // key path setter
global ::= type generic-signature 'TH' // key path equality
global ::= type generic-signature 'Th' // key path hasher
global ::= global generic-signature? 'TJ' AUTODIFF-FUNCTION-KIND INDEX-SUBSET 'p' INDEX-SUBSET 'r' // autodiff function

global ::= protocol 'TL' // protocol requirements base descriptor
global ::= assoc-type-name 'Tl' // associated type descriptor
Expand Down Expand Up @@ -271,6 +272,16 @@ are always non-polymorphic ``<impl-function-type>`` types.
``<VALUE-WITNESS-KIND>`` differentiates the kinds of value
witness functions for a type.

::

AUTODIFF-FUNCTION-KIND ::= 'f' // JVP (forward-mode derivative)
AUTODIFF-FUNCTION-KIND ::= 'r' // VJP (reverse-mode derivative)
AUTODIFF-FUNCTION-KIND ::= 'd' // differential
AUTODIFF-FUNCTION-KIND ::= 'p' // pullback

``<AUTODIFF-FUNCTION-KIND>`` differentiates the kinds of functions assocaited
with a differentiable function used for differentiable programming.

::

global ::= generic-signature? type 'WOy' // Outlined copy
Expand Down Expand Up @@ -1004,6 +1015,13 @@ Numbers and Indexes
``<INDEX>`` is a production for encoding numbers in contexts that can't
end in a digit; it's optimized for encoding smaller numbers.

::

INDEX-SUBSET ::= ('S' | 'U')+

``<INDEX-SUBSET>`` is encoded like a bit vector and is optimized for encoding
indices with a small upper bound.

Function Specializations
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
24 changes: 16 additions & 8 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,23 @@ class ASTMangler : public Mangler {
bool predefined);

/// Mangle the derivative function (JVP/VJP) for the given:
/// - Mangled original function name.
/// - Mangled original function declaration.
/// - Derivative function kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);
mangleAutoDiffDerivativeFunction(const AbstractFunctionDecl *originalAFD,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);

/// Mangle the linear map (differential/pullback) for the given:
/// - Mangled original function name.
/// - Mangled original function declaration.
/// - Linear map kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string mangleAutoDiffLinearMapHelper(StringRef name,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);
std::string mangleAutoDiffLinearMap(const AbstractFunctionDecl *originalAFD,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);

/// Mangle the AutoDiff generated declaration for the given:
/// - Generated declaration kind: linear map struct or branching trace enum.
Expand Down Expand Up @@ -255,6 +255,8 @@ class ASTMangler : public Mangler {

std::string mangleOpaqueTypeDecl(const ValueDecl *decl);

std::string mangleGenericSignature(const GenericSignature sig);

enum SpecialContext {
ObjCContext,
ClangImporterContext,
Expand Down Expand Up @@ -427,6 +429,12 @@ class ASTMangler : public Mangler {
void appendSymbolicReference(SymbolicReferent referent);

void appendOpaqueDeclName(const OpaqueTypeDecl *opaqueDecl);

void beginManglingWithAutoDiffOriginalFunction(
const AbstractFunctionDecl *afd);
void appendAutoDiffFunctionParts(char functionKindCode,
AutoDiffConfig config);
void appendIndexSubset(IndexSubset *indexSubset);
};

} // end namespace Mangle
Expand Down
12 changes: 12 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "swift/AST/TypeAlignments.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/SourceLoc.h"
#include "swift/Demangling/Demangle.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Error.h"

Expand Down Expand Up @@ -655,6 +656,17 @@ getDifferentiabilityWitnessGenericSignature(GenericSignature origGenSig,

} // end namespace swift

namespace swift {
namespace Demangle {

AutoDiffFunctionKind
getAutoDiffFunctionKind(AutoDiffDerivativeFunctionKind kind);

AutoDiffFunctionKind getAutoDiffFunctionKind(AutoDiffLinearMapKind kind);

} // end namespace autodiff
} // end namespace swift

namespace llvm {

using swift::AutoDiffConfig;
Expand Down
7 changes: 7 additions & 0 deletions include/swift/Demangling/Demangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ enum class FunctionSigSpecializationParamKind : unsigned {
ExistentialToGeneric = 1 << 10,
};

enum class AutoDiffFunctionKind : char {
JVP = 'f',
VJP = 'r',
Differential = 'd',
Pullback = 'p',
};

/// The pass that caused the specialization to occur. We use this to make sure
/// that two passes that generate similar changes do not yield the same
/// mangling. This currently cannot happen, so this is just a safety measure
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Demangling/DemangleNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ NODE(CanonicalPrespecializedGenericTypeCachingOnceToken)

// Added in Swift 5.5
NODE(AsyncFunctionPointer)
NODE(AutoDiffFunction)
NODE(AutoDiffFunctionKind)
NODE(IndexSubset)

#undef CONTEXT_NODE
#undef NODE
2 changes: 2 additions & 0 deletions include/swift/Demangling/Demangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ class Demangler : public NodeFactory {

NodePointer demangleTypeMangling();
NodePointer demangleSymbolicReference(unsigned char rawKind);
NodePointer demangleAutoDiffFunctionKind();
NodePointer demangleIndexSubset();

bool demangleBoundGenerics(Vector<NodePointer> &TypeListList,
NodePointer &RetroactiveConformances);
Expand Down
5 changes: 2 additions & 3 deletions include/swift/SILOptimizer/Differentiation/JVPCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ class JVPCloner final {
///
/// The parent JVP cloner stores the original function and an empty
/// to-be-generated pullback function.
explicit JVPCloner(ADContext &context, SILFunction *original,
SILDifferentiabilityWitness *witness, SILFunction *jvp,
DifferentiationInvoker invoker);
explicit JVPCloner(ADContext &context, SILDifferentiabilityWitness *witness,
SILFunction *jvp, DifferentiationInvoker invoker);
~JVPCloner();

/// Performs JVP generation on the empty JVP function. Returns true if any
Expand Down
5 changes: 2 additions & 3 deletions include/swift/SILOptimizer/Differentiation/VJPCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ class VJPCloner final {
///
/// The parent VJP cloner stores the original function and an empty
/// to-be-generated pullback function.
explicit VJPCloner(ADContext &context, SILFunction *original,
SILDifferentiabilityWitness *witness, SILFunction *vjp,
DifferentiationInvoker invoker);
explicit VJPCloner(ADContext &context, SILDifferentiabilityWitness *witness,
SILFunction *vjp, DifferentiationInvoker invoker);
~VJPCloner();

ADContext &getContext() const;
Expand Down
46 changes: 46 additions & 0 deletions include/swift/SILOptimizer/Utils/DifferentiationMangler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===------- DifferentiationMangler.h --------- differentiation -*- C++ -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_SIL_UTILS_DIFFERENTIATIONMANGLER_H
#define SWIFT_SIL_UTILS_DIFFERENTIATIONMANGLER_H

#include "swift/AST/ASTMangler.h"
#include "swift/AST/AutoDiff.h"
#include "swift/Basic/NullablePtr.h"
#include "swift/Demangling/Demangler.h"
#include "swift/SIL/SILFunction.h"

namespace swift {
namespace Mangle {

/// A mangler for generated differentiation functions.
class DifferentiationMangler : public ASTMangler {
public:
DifferentiationMangler() {}
/// Returns the mangled name for a differentiation function of the given kind.
std::string mangle(SILFunction *originalFunction,
Demangle::AutoDiffFunctionKind kind,
AutoDiffConfig config);
/// Returns the mangled name for a derivative function of the given kind.
std::string mangleDerivativeFunction(SILFunction *originalFunction,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);
/// Returns the mangled name for a linear map of the given kind.
std::string mangleLinearMap(SILFunction *originalFunction,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);
};

} // end namespace Mangle
} // end namespace swift

#endif /* SWIFT_SIL_UTILS_DIFFERENTIATIONMANGLER_H */
89 changes: 46 additions & 43 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTVisitor.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ExistentialLayout.h"
#include "swift/AST/FileUnit.h"
#include "swift/AST/GenericSignature.h"
Expand Down Expand Up @@ -405,55 +406,51 @@ std::string ASTMangler::mangleObjCAsyncCompletionHandlerImpl(
return finalize();
}

std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
StringRef name, AutoDiffDerivativeFunctionKind kind,
std::string ASTMangler::mangleAutoDiffDerivativeFunction(
const AbstractFunctionDecl *originalAFD,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
beginManglingWithoutPrefix();

Buffer << "AD__" << name << '_';
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
Buffer << "_jvp_";
break;
case AutoDiffDerivativeFunctionKind::VJP:
Buffer << "_vjp_";
break;
}
Buffer << config.mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
}

auto result = Storage.str().str();
Storage.clear();
return result;
beginManglingWithAutoDiffOriginalFunction(originalAFD);
appendAutoDiffFunctionParts((char)getAutoDiffFunctionKind(kind), config);
return finalize();
}

std::string ASTMangler::mangleAutoDiffLinearMapHelper(
StringRef name, AutoDiffLinearMapKind kind, AutoDiffConfig config) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
beginManglingWithoutPrefix();
std::string ASTMangler::mangleAutoDiffLinearMap(
const AbstractFunctionDecl *originalAFD, AutoDiffLinearMapKind kind,
AutoDiffConfig config) {
beginManglingWithAutoDiffOriginalFunction(originalAFD);
appendAutoDiffFunctionParts((char)getAutoDiffFunctionKind(kind), config);
return finalize();
}

Buffer << "AD__" << name << '_';
switch (kind) {
case AutoDiffLinearMapKind::Differential:
Buffer << "_differential_";
break;
case AutoDiffLinearMapKind::Pullback:
Buffer << "_pullback_";
break;
}
Buffer << config.mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
void ASTMangler::beginManglingWithAutoDiffOriginalFunction(
const AbstractFunctionDecl *afd) {
if (auto *attr = afd->getAttrs().getAttribute<SILGenNameAttr>()) {
beginManglingWithoutPrefix();
appendOperator(attr->Name);
return;
}
beginMangling();
if (auto *cd = dyn_cast<ConstructorDecl>(afd))
appendConstructorEntity(cd, /*isAllocating*/ !cd->isConvenienceInit());
else
appendEntity(afd);
}

auto result = Storage.str().str();
Storage.clear();
return result;
void ASTMangler::appendAutoDiffFunctionParts(char functionKindCode,
AutoDiffConfig config) {
if (auto sig = config.derivativeGenericSignature)
appendGenericSignature(sig);
appendOperator("TJ", StringRef(&functionKindCode, 1));
appendIndexSubset(config.parameterIndices);
appendOperator("p");
appendIndexSubset(config.resultIndices);
appendOperator("r");
}

/// Mangle the index subset.
void ASTMangler::appendIndexSubset(IndexSubset *indices) {
Buffer << indices->getString();
}

std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
Expand Down Expand Up @@ -741,6 +738,12 @@ std::string ASTMangler::mangleOpaqueTypeDecl(const ValueDecl *decl) {
return mangleDeclAsUSR(decl, MANGLING_PREFIX_STR);
}

std::string ASTMangler::mangleGenericSignature(const GenericSignature sig) {
beginMangling();
appendGenericSignature(sig);
return finalize();
}

void ASTMangler::appendSymbolKind(SymbolKind SKind) {
switch (SKind) {
case SymbolKind::Default: return;
Expand Down
20 changes: 20 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,23 @@ TangentPropertyInfo TangentStoredPropertyRequest::evaluate(
// Otherwise, tangent property is valid.
return TangentPropertyInfo(tanField);
}

Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind(
AutoDiffDerivativeFunctionKind kind) {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
return Demangle::AutoDiffFunctionKind::JVP;
case AutoDiffDerivativeFunctionKind::VJP:
return Demangle::AutoDiffFunctionKind::VJP;
}
}

Demangle::AutoDiffFunctionKind Demangle::getAutoDiffFunctionKind(
AutoDiffLinearMapKind kind) {
switch (kind) {
case AutoDiffLinearMapKind::Differential:
return Demangle::AutoDiffFunctionKind::Differential;
case AutoDiffLinearMapKind::Pullback:
return Demangle::AutoDiffFunctionKind::Pullback;
}
}
Loading