Skip to content

Commit cf28c00

Browse files
authored
[AutoDiff] Use ASTMangler for generated AD associated functions. (#26389)
Use ASTMangler for generated AD associated functions and linear maps: JVPs, VJPs, differentials, pullbacks. This centralizes ad-hoc mangling logic scattered throughout the codebase. Add `AutoDiffLinearMapKind` enum. Todos: - TF-680: Mangle `@differentiable` attribute requirements. - TF-685: Use ASTMangler for AD-related thunks. - TF-686: Make Demangler/Remangler work with AD-generated functions.
1 parent 0bf4460 commit cf28c00

File tree

9 files changed

+143
-70
lines changed

9 files changed

+143
-70
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,25 @@ class ASTMangler : public Mangler {
151151
Type FromType, Type ToType,
152152
Type SelfType,
153153
ModuleDecl *Module);
154-
154+
155+
// SWIFT_ENABLE_TENSORFLOW
156+
// Mangle the autodiff associated function (JVP/VJP) with the given:
157+
// - Mangled original function name.
158+
// - Associated function kind.
159+
// - Parameter/result indices.
160+
std::string mangleAutoDiffAssociatedFunctionHelper(
161+
StringRef name, AutoDiffAssociatedFunctionKind kind,
162+
const SILAutoDiffIndices &indices);
163+
164+
// SWIFT_ENABLE_TENSORFLOW
165+
// Mangle the autodiff linear map (differential/pullback) with the given:
166+
// - Mangled original function name.
167+
// - Linear map kind.
168+
// - Parameter/result indices.
169+
std::string mangleAutoDiffLinearMapHelper(
170+
StringRef name, AutoDiffLinearMapKind kind,
171+
const SILAutoDiffIndices &indices);
172+
155173
std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
156174
GenericSignature *signature,
157175
CanType baseType,

include/swift/AST/AutoDiff.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,20 @@ struct AutoDiffAssociatedFunctionKind {
546546
operator innerty() const { return rawValue; }
547547
};
548548

549+
/// The kind of an linear map.
550+
struct AutoDiffLinearMapKind {
551+
enum innerty : uint8_t {
552+
// The differential function.
553+
Differential = 0,
554+
// The pullback function.
555+
Pullback = 1
556+
} rawValue;
557+
558+
AutoDiffLinearMapKind() = default;
559+
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
560+
operator innerty() const { return rawValue; }
561+
};
562+
549563
/// In conjunction with the original function decl, identifies an associated
550564
/// autodiff function.
551565
///

lib/AST/ASTMangler.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,52 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
369369
return finalize();
370370
}
371371

372+
std::string ASTMangler::mangleAutoDiffAssociatedFunctionHelper(
373+
StringRef name, AutoDiffAssociatedFunctionKind kind,
374+
const SILAutoDiffIndices &indices) {
375+
// TODO(TF-20): Make the mangling scheme robust.
376+
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
377+
beginManglingWithoutPrefix();
378+
379+
Buffer << "AD__" << name << '_';
380+
switch (kind) {
381+
case AutoDiffAssociatedFunctionKind::JVP:
382+
Buffer << "_jvp_";
383+
break;
384+
case AutoDiffAssociatedFunctionKind::VJP:
385+
Buffer << "_vjp_";
386+
break;
387+
}
388+
Buffer << indices.mangle();
389+
390+
auto result = Storage.str().str();
391+
Storage.clear();
392+
return result;
393+
}
394+
395+
std::string ASTMangler::mangleAutoDiffLinearMapHelper(
396+
StringRef name, AutoDiffLinearMapKind kind,
397+
const SILAutoDiffIndices &indices) {
398+
// TODO(TF-20): Make the mangling scheme robust.
399+
// TODO(TF-680): Mangle `@differentiable` atttribute requirements as well.
400+
beginManglingWithoutPrefix();
401+
402+
Buffer << "AD__" << name << '_';
403+
switch (kind) {
404+
case AutoDiffLinearMapKind::Differential:
405+
Buffer << "_differential_";
406+
break;
407+
case AutoDiffLinearMapKind::Pullback:
408+
Buffer << "_pullback_";
409+
break;
410+
}
411+
Buffer << indices.mangle();
412+
413+
auto result = Storage.str().str();
414+
Storage.clear();
415+
return result;
416+
}
417+
372418
std::string ASTMangler::mangleTypeForDebugger(Type Ty, const DeclContext *DC) {
373419
PrettyStackTraceType prettyStackTrace(Ty->getASTContext(),
374420
"mangling type for debugger", Ty);

lib/SIL/SILDeclRef.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,9 @@ static void mangleClangDecl(raw_ostream &buffer,
677677
}
678678

679679
std::string SILDeclRef::mangle(ManglingKind MKind) const {
680+
using namespace Mangle;
681+
ASTMangler mangler;
682+
680683
// SWIFT_ENABLE_TENSORFLOW
681684
if (autoDiffAssociatedFunctionIdentifier) {
682685
std::string originalMangled = asAutoDiffOriginalFunction().mangle(MKind);
@@ -686,22 +689,11 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
686689
autoDiffAssociatedFunctionIdentifier->getParameterIndices()->getLowered(
687690
functionTy->getASTContext(), functionTy);
688691
SILAutoDiffIndices indices(/*source*/ 0, silParameterIndices);
689-
std::string mangledKind;
690-
switch (autoDiffAssociatedFunctionIdentifier->getKind()) {
691-
case AutoDiffAssociatedFunctionKind::JVP:
692-
mangledKind = "jvp";
693-
break;
694-
case AutoDiffAssociatedFunctionKind::VJP:
695-
mangledKind = "vjp";
696-
break;
697-
}
698-
return "AD__" + originalMangled + "__" + mangledKind + "_" +
699-
indices.mangle();
692+
auto assocFnKind = autoDiffAssociatedFunctionIdentifier->getKind();
693+
return mangler.mangleAutoDiffAssociatedFunctionHelper(
694+
originalMangled, assocFnKind, indices);
700695
}
701696

702-
using namespace Mangle;
703-
ASTMangler mangler;
704-
705697
// As a special case, Clang functions and globals don't get mangled at all.
706698
if (hasDecl()) {
707699
if (auto clangDecl = getDecl()->getClangDecl()) {

lib/SIL/SILFunctionBuilder.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "swift/SIL/SILFunctionBuilder.h"
1414
#include "swift/AST/Availability.h"
1515
#include "swift/AST/Decl.h"
16+
// SWIFT_ENABLE_TENSORFLOW
17+
#include "swift/AST/ASTMangler.h"
1618
using namespace swift;
1719

1820
SILFunction *SILFunctionBuilder::getOrCreateFunction(
@@ -104,12 +106,19 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
104106
indices.parameters->getNumIndices() > 1;
105107
if (isSelfReorderedMethod) {
106108
auto &ctx = F->getASTContext();
107-
if (A->getJVPFunction())
109+
if (A->getJVPFunction()) {
110+
Mangle::ASTMangler mangler;
108111
jvpName = ctx.getIdentifier(
109-
"AD__" + constant.mangle() + "__jvp_" + indices.mangle()).str();
112+
mangler.mangleAutoDiffAssociatedFunctionHelper(
113+
constant.mangle(), AutoDiffAssociatedFunctionKind::JVP,
114+
indices)).str();
115+
}
110116
if (A->getVJPFunction()) {
117+
Mangle::ASTMangler mangler;
111118
vjpName = ctx.getIdentifier(
112-
"AD__" + constant.mangle() + "__vjp_" + indices.mangle()).str();
119+
mangler.mangleAutoDiffAssociatedFunctionHelper(
120+
constant.mangle(), AutoDiffAssociatedFunctionKind::VJP,
121+
indices)).str();
113122
}
114123
} else {
115124
if (auto *jvpFn = A->getJVPFunction())

lib/SILGen/SILGenPoly.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3508,18 +3508,10 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk(
35083508
IsSerialized_t isSerialized) {
35093509
auto assocFnType = assocFn->getLoweredFunctionType();
35103510

3511-
std::string name;
3512-
switch (assocFnKind) {
3513-
case AutoDiffAssociatedFunctionKind::JVP:
3514-
name = "jvp";
3515-
break;
3516-
case AutoDiffAssociatedFunctionKind::VJP:
3517-
name = "vjp";
3518-
break;
3519-
}
3520-
name = getASTContext().getIdentifier(
3521-
"AD__" + original->getName().str() + "__" + name + "_" +
3522-
indices.mangle()).str();
3511+
Mangle::ASTMangler mangler;
3512+
auto name = getASTContext().getIdentifier(
3513+
mangler.mangleAutoDiffAssociatedFunctionHelper(
3514+
original->getName(), assocFnKind, indices)).str();
35233515

35243516
Lowering::GenericContextScope genericContextScope(
35253517
Types, assocFnType->getGenericSignature());

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,7 +2057,7 @@ emitAssociatedFunctionReference(
20572057
auto substMap = getSubstitutionMap(original);
20582058
// Attempt to look up a `[differentiable]` attribute that minimally
20592059
// satisfies the specified indices.
2060-
// TODO(TF-482): Change `lookupMinimalDifferentiableAttr` to additionally
2060+
// TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally
20612061
// check whether `[differentiable]` attribute generic requirements are
20622062
// satisfied.
20632063
auto *minimalAttr =
@@ -2112,7 +2112,7 @@ emitAssociatedFunctionReference(
21122112
}
21132113
assert(minimalAttr);
21142114
// TODO(TF-482): Move generic requirement checking logic to
2115-
// `lookupMinimalDifferentiableAttr`.
2115+
// `lookUpMinimalDifferentiableAttr`.
21162116
if (!checkRequirementsSatisfied(
21172117
minimalAttr->getRequirements(),
21182118
substMap, originalFn, context.getModule().getSwiftModule())) {
@@ -2959,10 +2959,11 @@ class VJPEmitter final
29592959
->getCanonicalType(), origParam.getConvention()));
29602960
}
29612961

2962-
auto pbName = original->getASTContext()
2963-
.getIdentifier("AD__" + original->getName().str() +
2964-
"__pullback_" + indices.mangle())
2965-
.str();
2962+
Mangle::ASTMangler mangler;
2963+
auto pbName = original->getASTContext().getIdentifier(
2964+
mangler.mangleAutoDiffLinearMapHelper(
2965+
original->getName(), AutoDiffLinearMapKind::Pullback,
2966+
indices)).str();
29662967
auto pbGenericSig = getAssociatedFunctionGenericSignature(attr, original);
29672968
auto *pbGenericEnv = pbGenericSig
29682969
? pbGenericSig->createGenericEnvironment()
@@ -3460,7 +3461,7 @@ class VJPEmitter final
34603461
void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) {
34613462
// Clone `autodiff_function` from original to VJP, then add the cloned
34623463
// instruction to the `autodiff_function` worklist.
3463-
SILClonerWithScopes::visitAutoDiffFunctionInst(adfi);
3464+
TypeSubstCloner::visitAutoDiffFunctionInst(adfi);
34643465
auto *newADFI = cast<AutoDiffFunctionInst>(getOpValue(adfi));
34653466
context.getAutoDiffFunctionInsts().push_back(newADFI);
34663467
}
@@ -3550,11 +3551,11 @@ class JVPEmitter final
35503551
->getAutoDiffAssociatedTangentSpace(lookupConformance)
35513552
->getCanonicalType(), origParam.getConvention()));
35523553
}
3553-
3554-
auto diffName = original->getASTContext()
3555-
.getIdentifier("AD__" + original->getName().str() + "__differential_" +
3556-
indices.mangle())
3557-
.str();
3554+
Mangle::ASTMangler mangler;
3555+
auto diffName = original->getASTContext().getIdentifier(
3556+
mangler.mangleAutoDiffLinearMapHelper(
3557+
original->getName(), AutoDiffLinearMapKind::Differential,
3558+
indices)).str();
35583559
auto diffGenericSig = getAssociatedFunctionGenericSignature(attr, original);
35593560
auto *diffGenericEnv = diffGenericSig
35603561
? diffGenericSig->createGenericEnvironment()
@@ -5918,10 +5919,11 @@ static SILFunction *createEmptyVJP(
59185919
auto indices = attr->getIndices();
59195920

59205921
// === Create an empty VJP. ===
5921-
auto vjpName = original->getASTContext()
5922-
.getIdentifier("AD__" + original->getName().str() +
5923-
"__vjp_" + indices.mangle())
5924-
.str();
5922+
Mangle::ASTMangler mangler;
5923+
auto vjpName = original->getASTContext().getIdentifier(
5924+
mangler.mangleAutoDiffAssociatedFunctionHelper(
5925+
original->getName(), AutoDiffAssociatedFunctionKind::VJP, indices))
5926+
.str();
59255927
auto vjpGenericSig = getAssociatedFunctionGenericSignature(attr, original);
59265928

59275929
// RAII that pushes the original function's generic signature to
@@ -5968,10 +5970,11 @@ static SILFunction *createEmptyJVP(
59685970
auto indices = attr->getIndices();
59695971

59705972
// === Create an empty JVP. ===
5971-
auto jvpName = original->getASTContext()
5972-
.getIdentifier("AD__" + original->getName().str() +
5973-
"__jvp_" + indices.mangle())
5974-
.str();
5973+
Mangle::ASTMangler mangler;
5974+
auto jvpName = original->getASTContext().getIdentifier(
5975+
mangler.mangleAutoDiffAssociatedFunctionHelper(
5976+
original->getName(), AutoDiffAssociatedFunctionKind::JVP, indices))
5977+
.str();
59755978
auto jvpGenericSig = getAssociatedFunctionGenericSignature(attr, original);
59765979

59775980
// RAII that pushes the original function's generic signature to
@@ -6017,10 +6020,11 @@ bool ADContext::processDifferentiableAttribute(
60176020
if (attr->hasJVP()) {
60186021
jvpName = attr->getJVPName();
60196022
} else if (original->isExternalDeclaration()) {
6020-
jvpName = original->getASTContext()
6021-
.getIdentifier("AD__" + original->getName().str() +
6022-
"__jvp_" + attr->getIndices().mangle())
6023-
.str();
6023+
Mangle::ASTMangler mangler;
6024+
jvpName = original->getASTContext().getIdentifier(
6025+
mangler.mangleAutoDiffAssociatedFunctionHelper(
6026+
original->getName(), AutoDiffAssociatedFunctionKind::JVP,
6027+
attr->getIndices())).str();
60246028
}
60256029
if (!jvpName.empty()) {
60266030
jvp = module.lookUpFunction(jvpName);
@@ -6044,10 +6048,11 @@ bool ADContext::processDifferentiableAttribute(
60446048
if (attr->hasVJP()) {
60456049
vjpName = attr->getVJPName();
60466050
} else if (original->isExternalDeclaration()) {
6047-
vjpName = original->getASTContext()
6048-
.getIdentifier("AD__" + original->getName().str() +
6049-
"__vjp_" + attr->getIndices().mangle())
6050-
.str();
6051+
Mangle::ASTMangler mangler;
6052+
vjpName = original->getASTContext().getIdentifier(
6053+
mangler.mangleAutoDiffAssociatedFunctionHelper(
6054+
original->getName(), AutoDiffAssociatedFunctionKind::VJP,
6055+
attr->getIndices())).str();
60516056
}
60526057
if (!vjpName.empty()) {
60536058
vjp = module.lookUpFunction(vjpName);
@@ -6142,7 +6147,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
61426147
/*withoutActuallyEscaping*/ true,
61436148
DifferentiationThunkKind::Reabstraction);
61446149

6145-
// TODO: Use more principled mangling.
6150+
// TODO(TF-685): Use more principled mangling for thunks.
61466151
std::string thunkName;
61476152
switch (kind) {
61486153
case AutoDiffAssociatedFunctionKind::JVP:
@@ -6408,7 +6413,7 @@ ADContext::getOrCreateSubsetParametersThunkForAssociatedFunction(
64086413
->getAbstractFunctionDecl()->getNameStr();
64096414
}
64106415
assert(!origName.empty() && "Original function name could not be resolved");
6411-
// TODO: Use more principled mangling.
6416+
// TODO(TF-685): Use more principled mangling for thunks.
64126417
std::string thunkName;
64136418
switch (kind) {
64146419
case AutoDiffAssociatedFunctionKind::JVP:
@@ -6531,6 +6536,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
65316536
if (auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee())) {
65326537
SILAutoDiffIndices desiredIndices(resultIndex, parameterIndices);
65336538
auto *thunk = thunkRef->getReferencedFunctionOrNull();
6539+
// TODO(TF-685): Use more principled mangling for thunks.
65346540
auto newThunkName = "AD__" + thunk->getName().str() +
65356541
"__differentiable_curry_thunk_" + desiredIndices.mangle();
65366542

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,6 @@ func weirdWrapper<T : Differentiable>(_ x: T) -> T {
3535
}
3636
_ = gradient(at: Float(1), in: { x in weirdWrapper(x) })
3737

38-
/*
39-
// FIXME(TF-482): This currently crashes during differentiation transform.
40-
// because `T` is not constrained to `Differentiable` in generated
41-
// `[differentiable]` attribute.
42-
@differentiable
43-
func directMissingConformance<T>(_ x: T) -> T {
44-
return x
45-
}
46-
*/
47-
4838
@differentiable
4939
func direct<T : Differentiable>(_ x: T) -> T {
5040
return x

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,12 @@ func invalidRequirementLayout<Scalar>(x: Scalar) -> Scalar {
613613
return x
614614
}
615615

616+
// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
617+
@differentiable
618+
func missingConformance<T>(_ x: T) -> T {
619+
return x
620+
}
621+
616622
protocol ProtocolRequirements : Differentiable {
617623
// expected-note @+2 {{protocol requires initializer 'init(x:y:)' with type '(x: Float, y: Float)'}}
618624
@differentiable

0 commit comments

Comments
 (0)