Skip to content

Commit 04686a0

Browse files
committed
[AutoDiff] NFC: add ASTMangler entry points for AutoDiff-generated decls.
Add `ASTMangler::mangleAutoDiffGeneratedDeclaration`.
1 parent b2fe997 commit 04686a0

File tree

4 files changed

+77
-26
lines changed

4 files changed

+77
-26
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,19 @@ class ASTMangler : public Mangler {
172172
AutoDiffLinearMapKind kind,
173173
AutoDiffConfig config);
174174

175+
/// Mangle the AutoDiff generated declaration for the given:
176+
/// - Generated declaration kind: linear map struct or branching trace enum.
177+
/// - Mangled original function name.
178+
/// - Basic block number.
179+
/// - Linear map kind: differential or pullback.
180+
/// - Derivative function configuration: parameter/result indices and
181+
/// derivative generic signature.
182+
std::string
183+
mangleAutoDiffGeneratedDeclaration(AutoDiffGeneratedDeclarationKind declKind,
184+
StringRef origFnName, unsigned bbId,
185+
AutoDiffLinearMapKind linearMapKind,
186+
AutoDiffConfig config);
187+
175188
/// Mangle a SIL differentiability witness key:
176189
/// - Mangled original function name.
177190
/// - Parameter indices.

include/swift/AST/AutoDiff.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ struct DifferentiabilityWitnessFunctionKind {
164164
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
165165
};
166166

167+
/// The kind of a declaration generated by the differentiation transform.
168+
enum class AutoDiffGeneratedDeclarationKind : uint8_t {
169+
LinearMapStruct,
170+
BranchingTraceEnum
171+
};
172+
167173
/// SIL-level automatic differentiation indices. Consists of:
168174
/// - Parameter indices: indices of parameters to differentiate with respect to.
169175
/// - Result index: index of the result to differentiate from.

lib/AST/ASTMangler.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,45 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
424424
return result;
425425
}
426426

427+
std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
428+
AutoDiffGeneratedDeclarationKind declKind, StringRef origFnName,
429+
unsigned bbId, AutoDiffLinearMapKind linearMapKind, AutoDiffConfig config) {
430+
beginManglingWithoutPrefix();
431+
432+
Buffer << "_AD__" << origFnName << "_bb" + std::to_string(bbId);
433+
switch (declKind) {
434+
case AutoDiffGeneratedDeclarationKind::LinearMapStruct:
435+
switch (linearMapKind) {
436+
case AutoDiffLinearMapKind::Differential:
437+
Buffer << "__DF__";
438+
break;
439+
case AutoDiffLinearMapKind::Pullback:
440+
Buffer << "__PB__";
441+
break;
442+
}
443+
break;
444+
case AutoDiffGeneratedDeclarationKind::BranchingTraceEnum:
445+
switch (linearMapKind) {
446+
case AutoDiffLinearMapKind::Differential:
447+
Buffer << "__Succ__";
448+
break;
449+
case AutoDiffLinearMapKind::Pullback:
450+
Buffer << "__Pred__";
451+
break;
452+
}
453+
break;
454+
}
455+
Buffer << config.getSILAutoDiffIndices().mangle();
456+
if (config.derivativeGenericSignature) {
457+
Buffer << '_';
458+
appendGenericSignature(config.derivativeGenericSignature);
459+
}
460+
461+
auto result = Storage.str().str();
462+
Storage.clear();
463+
return result;
464+
}
465+
427466
std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
428467
SILDifferentiabilityWitnessKey key) {
429468
// TODO(TF-20): Make the mangling scheme robust. Support demangling.

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -130,19 +130,15 @@ EnumDecl *LinearMapInfo::createBranchingTraceDecl(
130130
auto *moduleDecl = original->getModule().getSwiftModule();
131131
auto &file = getDeclarationFileUnit();
132132
// Create a branching trace enum.
133-
std::string enumName;
134-
switch (kind) {
135-
case AutoDiffLinearMapKind::Differential:
136-
enumName = "_AD__" + original->getName().str() + "_bb" +
137-
std::to_string(originalBB->getDebugID()) + "__Succ__" +
138-
indices.mangle();
139-
break;
140-
case AutoDiffLinearMapKind::Pullback:
141-
enumName = "_AD__" + original->getName().str() + "_bb" +
142-
std::to_string(originalBB->getDebugID()) + "__Pred__" +
143-
indices.mangle();
144-
break;
145-
}
133+
Mangle::ASTMangler mangler;
134+
auto *resultIndices = IndexSubset::get(
135+
original->getASTContext(),
136+
original->getLoweredFunctionType()->getNumResults(), indices.source);
137+
auto *parameterIndices = indices.parameters;
138+
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
139+
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
140+
AutoDiffGeneratedDeclarationKind::BranchingTraceEnum,
141+
original->getName().str(), originalBB->getDebugID(), kind, config);
146142
auto enumId = astCtx.getIdentifier(enumName);
147143
auto loc = original->getLocation().getSourceLoc();
148144
GenericParamList *genericParams = nullptr;
@@ -205,19 +201,16 @@ LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
205201
auto *original = originalBB->getParent();
206202
auto &astCtx = original->getASTContext();
207203
auto &file = getDeclarationFileUnit();
208-
std::string structName;
209-
switch (kind) {
210-
case swift::AutoDiffLinearMapKind::Differential:
211-
structName = "_AD__" + original->getName().str() + "_bb" +
212-
std::to_string(originalBB->getDebugID()) + "__DF__" +
213-
indices.mangle();
214-
break;
215-
case swift::AutoDiffLinearMapKind::Pullback:
216-
structName = "_AD__" + original->getName().str() + "_bb" +
217-
std::to_string(originalBB->getDebugID()) + "__PB__" +
218-
indices.mangle();
219-
break;
220-
}
204+
// Create a linear map struct.
205+
Mangle::ASTMangler mangler;
206+
auto *resultIndices = IndexSubset::get(
207+
original->getASTContext(),
208+
original->getLoweredFunctionType()->getNumResults(), indices.source);
209+
auto *parameterIndices = indices.parameters;
210+
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
211+
auto structName = mangler.mangleAutoDiffGeneratedDeclaration(
212+
AutoDiffGeneratedDeclarationKind::LinearMapStruct,
213+
original->getName().str(), originalBB->getDebugID(), kind, config);
221214
auto structId = astCtx.getIdentifier(structName);
222215
GenericParamList *genericParams = nullptr;
223216
if (genericSig)

0 commit comments

Comments
 (0)