Skip to content

Commit 5d8d8da

Browse files
authored
Merge pull request #30852 from dan-zheng/autodiff-cleanup
[AutoDiff] NFC: gardening.
2 parents 78880ff + 77d0d99 commit 5d8d8da

File tree

8 files changed

+112
-65
lines changed

8 files changed

+112
-65
lines changed

include/swift/AST/ASTMangler.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,19 @@ class ASTMangler : public Mangler {
172172
AutoDiffLinearMapKind kind,
173173
AutoDiffConfig config);
174174

175+
/// Mangle the AutoDiff generated declaration for the given:
176+
/// - Generated declaration kind: linear map struct or branching trace enum.
177+
/// - Mangled original function name.
178+
/// - Basic block number.
179+
/// - Linear map kind: differential or pullback.
180+
/// - Derivative function configuration: parameter/result indices and
181+
/// derivative generic signature.
182+
std::string
183+
mangleAutoDiffGeneratedDeclaration(AutoDiffGeneratedDeclarationKind declKind,
184+
StringRef origFnName, unsigned bbId,
185+
AutoDiffLinearMapKind linearMapKind,
186+
AutoDiffConfig config);
187+
175188
/// Mangle a SIL differentiability witness key:
176189
/// - Mangled original function name.
177190
/// - Parameter indices.

include/swift/AST/AutoDiff.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2019 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -31,6 +31,7 @@
3131
namespace swift {
3232

3333
class AnyFunctionType;
34+
class SourceFile;
3435
class SILFunctionType;
3536
class TupleType;
3637

@@ -164,6 +165,12 @@ struct DifferentiabilityWitnessFunctionKind {
164165
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
165166
};
166167

168+
/// The kind of a declaration generated by the differentiation transform.
169+
enum class AutoDiffGeneratedDeclarationKind : uint8_t {
170+
LinearMapStruct,
171+
BranchingTraceEnum
172+
};
173+
167174
/// SIL-level automatic differentiation indices. Consists of:
168175
/// - Parameter indices: indices of parameters to differentiate with respect to.
169176
/// - Result index: index of the result to differentiate from.
@@ -386,6 +393,9 @@ class TangentSpace {
386393
/// derivative generic signature.
387394
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;
388395

396+
/// Returns `true` iff differentiable programming is enabled.
397+
bool isDifferentiableProgrammingEnabled(SourceFile &SF);
398+
389399
/// Automatic differentiation utility namespace.
390400
namespace autodiff {
391401

include/swift/SILOptimizer/Utils/Differentiation/LinearMapInfo.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,30 +113,26 @@ class LinearMapInfo {
113113
/// whose cases represent the predecessors/successors of the given original
114114
/// block.
115115
EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
116-
SILAutoDiffIndices indices,
117116
CanGenericSignature genericSig,
118117
SILLoopInfo *loopInfo);
119118

120119
/// Creates a struct declaration with the given JVP/VJP generic signature, for
121120
/// storing the linear map values and predecessor/successor basic block of the
122121
/// given original block.
123122
StructDecl *createLinearMapStruct(SILBasicBlock *originalBB,
124-
SILAutoDiffIndices indices,
125123
CanGenericSignature genericSig);
126124

127125
/// Adds a linear map field to the linear map struct.
128126
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType);
129127

130128
/// Given an `apply` instruction, conditionally adds a linear map struct field
131129
/// for its linear map function if it is active.
132-
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
133-
SILAutoDiffIndices indices);
130+
void addLinearMapToStruct(ADContext &context, ApplyInst *ai);
134131

135132
/// Generates linear map struct and branching enum declarations for the given
136133
/// function. Linear map structs are populated with linear map fields and a
137134
/// branching enum field.
138135
void generateDifferentiationDataStructures(ADContext &context,
139-
SILAutoDiffIndices indices,
140136
SILFunction *derivative);
141137

142138
public:

lib/AST/ASTMangler.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,45 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
424424
return result;
425425
}
426426

427+
std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
428+
AutoDiffGeneratedDeclarationKind declKind, StringRef origFnName,
429+
unsigned bbId, AutoDiffLinearMapKind linearMapKind, AutoDiffConfig config) {
430+
beginManglingWithoutPrefix();
431+
432+
Buffer << "_AD__" << origFnName << "_bb" + std::to_string(bbId);
433+
switch (declKind) {
434+
case AutoDiffGeneratedDeclarationKind::LinearMapStruct:
435+
switch (linearMapKind) {
436+
case AutoDiffLinearMapKind::Differential:
437+
Buffer << "__DF__";
438+
break;
439+
case AutoDiffLinearMapKind::Pullback:
440+
Buffer << "__PB__";
441+
break;
442+
}
443+
break;
444+
case AutoDiffGeneratedDeclarationKind::BranchingTraceEnum:
445+
switch (linearMapKind) {
446+
case AutoDiffLinearMapKind::Differential:
447+
Buffer << "__Succ__";
448+
break;
449+
case AutoDiffLinearMapKind::Pullback:
450+
Buffer << "__Pred__";
451+
break;
452+
}
453+
break;
454+
}
455+
Buffer << config.getSILAutoDiffIndices().mangle();
456+
if (config.derivativeGenericSignature) {
457+
Buffer << '_';
458+
appendGenericSignature(config.derivativeGenericSignature);
459+
}
460+
461+
auto result = Storage.str().str();
462+
Storage.clear();
463+
return result;
464+
}
465+
427466
std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
428467
SILDifferentiabilityWitnessKey key) {
429468
// TODO(TF-20): Make the mangling scheme robust. Support demangling.

lib/AST/AutoDiff.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "swift/AST/AutoDiff.h"
1414
#include "swift/AST/ASTContext.h"
1515
#include "swift/AST/GenericEnvironment.h"
16+
#include "swift/AST/ImportCache.h"
1617
#include "swift/AST/Module.h"
1718
#include "swift/AST/TypeCheckRequests.h"
1819
#include "swift/AST/Types.h"
@@ -124,6 +125,23 @@ void AutoDiffConfig::print(llvm::raw_ostream &s) const {
124125
s << ')';
125126
}
126127

128+
bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) {
129+
auto &ctx = SF.getASTContext();
130+
// Return true if differentiable programming is explicitly enabled.
131+
if (ctx.LangOpts.EnableExperimentalDifferentiableProgramming)
132+
return true;
133+
// Otherwise, return true iff the `_Differentiation` module is imported in
134+
// the given source file.
135+
bool importsDifferentiationModule = false;
136+
for (auto import : namelookup::getAllImports(&SF)) {
137+
if (import.second->getName() == ctx.Id_Differentiation) {
138+
importsDifferentiationModule = true;
139+
break;
140+
}
141+
}
142+
return importsDifferentiationModule;
143+
}
144+
127145
// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
128146
// most once (for curried method types) is sufficient.
129147
static void unwrapCurryLevels(AnyFunctionType *fnTy,

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
5959
: kind(kind), original(original), derivative(derivative),
6060
activityInfo(activityInfo), indices(indices),
6161
typeConverter(context.getTypeConverter()) {
62-
generateDifferentiationDataStructures(context, indices, derivative);
62+
generateDifferentiationDataStructures(context, derivative);
6363
}
6464

6565
SILType LinearMapInfo::remapTypeInDerivative(SILType ty) {
@@ -122,27 +122,24 @@ void LinearMapInfo::computeAccessLevel(NominalTypeDecl *nominal,
122122
}
123123
}
124124

125-
EnumDecl *LinearMapInfo::createBranchingTraceDecl(
126-
SILBasicBlock *originalBB, SILAutoDiffIndices indices,
127-
CanGenericSignature genericSig, SILLoopInfo *loopInfo) {
125+
EnumDecl *
126+
LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
127+
CanGenericSignature genericSig,
128+
SILLoopInfo *loopInfo) {
128129
assert(originalBB->getParent() == original);
129130
auto &astCtx = original->getASTContext();
130131
auto *moduleDecl = original->getModule().getSwiftModule();
131132
auto &file = getDeclarationFileUnit();
132133
// Create a branching trace enum.
133-
std::string enumName;
134-
switch (kind) {
135-
case AutoDiffLinearMapKind::Differential:
136-
enumName = "_AD__" + original->getName().str() + "_bb" +
137-
std::to_string(originalBB->getDebugID()) + "__Succ__" +
138-
indices.mangle();
139-
break;
140-
case AutoDiffLinearMapKind::Pullback:
141-
enumName = "_AD__" + original->getName().str() + "_bb" +
142-
std::to_string(originalBB->getDebugID()) + "__Pred__" +
143-
indices.mangle();
144-
break;
145-
}
134+
Mangle::ASTMangler mangler;
135+
auto *resultIndices = IndexSubset::get(
136+
original->getASTContext(),
137+
original->getLoweredFunctionType()->getNumResults(), indices.source);
138+
auto *parameterIndices = indices.parameters;
139+
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
140+
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
141+
AutoDiffGeneratedDeclarationKind::BranchingTraceEnum,
142+
original->getName().str(), originalBB->getDebugID(), kind, config);
146143
auto enumId = astCtx.getIdentifier(enumName);
147144
auto loc = original->getLocation().getSourceLoc();
148145
GenericParamList *genericParams = nullptr;
@@ -199,25 +196,21 @@ EnumDecl *LinearMapInfo::createBranchingTraceDecl(
199196

200197
StructDecl *
201198
LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
202-
SILAutoDiffIndices indices,
203199
CanGenericSignature genericSig) {
204200
assert(originalBB->getParent() == original);
205201
auto *original = originalBB->getParent();
206202
auto &astCtx = original->getASTContext();
207203
auto &file = getDeclarationFileUnit();
208-
std::string structName;
209-
switch (kind) {
210-
case swift::AutoDiffLinearMapKind::Differential:
211-
structName = "_AD__" + original->getName().str() + "_bb" +
212-
std::to_string(originalBB->getDebugID()) + "__DF__" +
213-
indices.mangle();
214-
break;
215-
case swift::AutoDiffLinearMapKind::Pullback:
216-
structName = "_AD__" + original->getName().str() + "_bb" +
217-
std::to_string(originalBB->getDebugID()) + "__PB__" +
218-
indices.mangle();
219-
break;
220-
}
204+
// Create a linear map struct.
205+
Mangle::ASTMangler mangler;
206+
auto *resultIndices = IndexSubset::get(
207+
original->getASTContext(),
208+
original->getLoweredFunctionType()->getNumResults(), indices.source);
209+
auto *parameterIndices = indices.parameters;
210+
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
211+
auto structName = mangler.mangleAutoDiffGeneratedDeclaration(
212+
AutoDiffGeneratedDeclarationKind::LinearMapStruct,
213+
original->getName().str(), originalBB->getDebugID(), kind, config);
221214
auto structId = astCtx.getIdentifier(structName);
222215
GenericParamList *genericParams = nullptr;
223216
if (genericSig)
@@ -274,8 +267,7 @@ VarDecl *LinearMapInfo::addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
274267
return linearMapDecl;
275268
}
276269

277-
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
278-
SILAutoDiffIndices indices) {
270+
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai) {
279271
SmallVector<SILValue, 4> allResults;
280272
SmallVector<unsigned, 8> activeParamIndices;
281273
SmallVector<unsigned, 8> activeResultIndices;
@@ -379,7 +371,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
379371
}
380372

381373
void LinearMapInfo::generateDifferentiationDataStructures(
382-
ADContext &context, SILAutoDiffIndices indices, SILFunction *derivativeFn) {
374+
ADContext &context, SILFunction *derivativeFn) {
383375
auto &astCtx = original->getASTContext();
384376
auto *loopAnalysis = context.getPassManager().getAnalysis<SILLoopAnalysis>();
385377
auto *loopInfo = loopAnalysis->get(original);
@@ -392,8 +384,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
392384

393385
// Create linear map struct for each original block.
394386
for (auto &origBB : *original) {
395-
auto *linearMapStruct =
396-
createLinearMapStruct(&origBB, indices, derivativeFnGenSig);
387+
auto *linearMapStruct = createLinearMapStruct(&origBB, derivativeFnGenSig);
397388
linearMapStructs.insert({&origBB, linearMapStruct});
398389
}
399390

@@ -409,8 +400,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
409400
break;
410401
}
411402
for (auto &origBB : *original) {
412-
auto *traceEnum = createBranchingTraceDecl(&origBB, indices,
413-
derivativeFnGenSig, loopInfo);
403+
auto *traceEnum =
404+
createBranchingTraceDecl(&origBB, derivativeFnGenSig, loopInfo);
414405
branchingTraceDecls.insert({&origBB, traceEnum});
415406
if (origBB.isEntry())
416407
continue;
@@ -433,7 +424,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
433424
continue;
434425
LLVM_DEBUG(getADDebugStream()
435426
<< "Adding linear map struct field for " << *ai);
436-
addLinearMapToStruct(context, ai, indices);
427+
addLinearMapToStruct(context, ai);
437428
}
438429
}
439430
}

lib/Sema/TypeChecker.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -410,23 +410,6 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) {
410410
#endif
411411
}
412412

413-
bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) {
414-
auto &ctx = SF.getASTContext();
415-
// Return true if differentiable programming is explicitly enabled.
416-
if (ctx.LangOpts.EnableExperimentalDifferentiableProgramming)
417-
return true;
418-
// Otherwise, return true iff the `_Differentiation` module is imported in
419-
// the given source file.
420-
bool importsDifferentiationModule = false;
421-
for (auto import : namelookup::getAllImports(&SF)) {
422-
if (import.second->getName() == ctx.Id_Differentiation) {
423-
importsDifferentiationModule = true;
424-
break;
425-
}
426-
}
427-
return importsDifferentiationModule;
428-
}
429-
430413
bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) {
431414
auto &ctx = SF.getASTContext();
432415
// Return true if `AdditiveArithmetic` derived conformances are explicitly

lib/Sema/TypeChecker.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,9 +1520,6 @@ bool isMemberOperator(FuncDecl *decl, Type type);
15201520
/// Complain if @objc or dynamic is used without importing Foundation.
15211521
void diagnoseAttrsRequiringFoundation(SourceFile &SF);
15221522

1523-
/// Returns `true` iff differentiable programming is enabled.
1524-
bool isDifferentiableProgrammingEnabled(SourceFile &SF);
1525-
15261523
/// Returns `true` iff `AdditiveArithmetic` derived conformances are enabled.
15271524
bool isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF);
15281525

0 commit comments

Comments
 (0)