Skip to content

[AutoDiff] NFC: gardening. #30852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,19 @@ class ASTMangler : public Mangler {
AutoDiffLinearMapKind kind,
AutoDiffConfig config);

/// Mangle the AutoDiff generated declaration for the given:
/// - Generated declaration kind: linear map struct or branching trace enum.
/// - Mangled original function name.
/// - Basic block number.
/// - Linear map kind: differential or pullback.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string
mangleAutoDiffGeneratedDeclaration(AutoDiffGeneratedDeclarationKind declKind,
StringRef origFnName, unsigned bbId,
AutoDiffLinearMapKind linearMapKind,
AutoDiffConfig config);

/// Mangle a SIL differentiability witness key:
/// - Mangled original function name.
/// - Parameter indices.
Expand Down
12 changes: 11 additions & 1 deletion include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
Expand Down Expand Up @@ -31,6 +31,7 @@
namespace swift {

class AnyFunctionType;
class SourceFile;
class SILFunctionType;
class TupleType;

Expand Down Expand Up @@ -164,6 +165,12 @@ struct DifferentiabilityWitnessFunctionKind {
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

/// The kind of a declaration generated by the differentiation transform.
enum class AutoDiffGeneratedDeclarationKind : uint8_t {
LinearMapStruct,
BranchingTraceEnum
};

/// SIL-level automatic differentiation indices. Consists of:
/// - Parameter indices: indices of parameters to differentiate with respect to.
/// - Result index: index of the result to differentiate from.
Expand Down Expand Up @@ -386,6 +393,9 @@ class TangentSpace {
/// derivative generic signature.
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;

/// Returns `true` iff differentiable programming is enabled.
bool isDifferentiableProgrammingEnabled(SourceFile &SF);

/// Automatic differentiation utility namespace.
namespace autodiff {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,26 @@ class LinearMapInfo {
/// whose cases represent the predecessors/successors of the given original
/// block.
EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
SILAutoDiffIndices indices,
CanGenericSignature genericSig,
SILLoopInfo *loopInfo);

/// Creates a struct declaration with the given JVP/VJP generic signature, for
/// storing the linear map values and predecessor/successor basic block of the
/// given original block.
StructDecl *createLinearMapStruct(SILBasicBlock *originalBB,
SILAutoDiffIndices indices,
CanGenericSignature genericSig);

/// Adds a linear map field to the linear map struct.
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType);

/// Given an `apply` instruction, conditionally adds a linear map struct field
/// for its linear map function if it is active.
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
SILAutoDiffIndices indices);
void addLinearMapToStruct(ADContext &context, ApplyInst *ai);

/// Generates linear map struct and branching enum declarations for the given
/// function. Linear map structs are populated with linear map fields and a
/// branching enum field.
void generateDifferentiationDataStructures(ADContext &context,
SILAutoDiffIndices indices,
SILFunction *derivative);

public:
Expand Down
39 changes: 39 additions & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,45 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
return result;
}

std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
AutoDiffGeneratedDeclarationKind declKind, StringRef origFnName,
unsigned bbId, AutoDiffLinearMapKind linearMapKind, AutoDiffConfig config) {
beginManglingWithoutPrefix();

Buffer << "_AD__" << origFnName << "_bb" + std::to_string(bbId);
switch (declKind) {
case AutoDiffGeneratedDeclarationKind::LinearMapStruct:
switch (linearMapKind) {
case AutoDiffLinearMapKind::Differential:
Buffer << "__DF__";
break;
case AutoDiffLinearMapKind::Pullback:
Buffer << "__PB__";
break;
}
break;
case AutoDiffGeneratedDeclarationKind::BranchingTraceEnum:
switch (linearMapKind) {
case AutoDiffLinearMapKind::Differential:
Buffer << "__Succ__";
break;
case AutoDiffLinearMapKind::Pullback:
Buffer << "__Pred__";
break;
}
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
}

auto result = Storage.str().str();
Storage.clear();
return result;
}

std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
SILDifferentiabilityWitnessKey key) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
Expand Down
18 changes: 18 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "swift/AST/AutoDiff.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/ImportCache.h"
#include "swift/AST/Module.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
Expand Down Expand Up @@ -124,6 +125,23 @@ void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << ')';
}

bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) {
auto &ctx = SF.getASTContext();
// Return true if differentiable programming is explicitly enabled.
if (ctx.LangOpts.EnableExperimentalDifferentiableProgramming)
return true;
// Otherwise, return true iff the `_Differentiation` module is imported in
// the given source file.
bool importsDifferentiationModule = false;
for (auto import : namelookup::getAllImports(&SF)) {
if (import.second->getName() == ctx.Id_Differentiation) {
importsDifferentiationModule = true;
break;
}
}
return importsDifferentiationModule;
}

// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
// most once (for curried method types) is sufficient.
static void unwrapCurryLevels(AnyFunctionType *fnTy,
Expand Down
69 changes: 30 additions & 39 deletions lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
: kind(kind), original(original), derivative(derivative),
activityInfo(activityInfo), indices(indices),
typeConverter(context.getTypeConverter()) {
generateDifferentiationDataStructures(context, indices, derivative);
generateDifferentiationDataStructures(context, derivative);
}

SILType LinearMapInfo::remapTypeInDerivative(SILType ty) {
Expand Down Expand Up @@ -122,27 +122,24 @@ void LinearMapInfo::computeAccessLevel(NominalTypeDecl *nominal,
}
}

EnumDecl *LinearMapInfo::createBranchingTraceDecl(
SILBasicBlock *originalBB, SILAutoDiffIndices indices,
CanGenericSignature genericSig, SILLoopInfo *loopInfo) {
EnumDecl *
LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
CanGenericSignature genericSig,
SILLoopInfo *loopInfo) {
assert(originalBB->getParent() == original);
auto &astCtx = original->getASTContext();
auto *moduleDecl = original->getModule().getSwiftModule();
auto &file = getDeclarationFileUnit();
// Create a branching trace enum.
std::string enumName;
switch (kind) {
case AutoDiffLinearMapKind::Differential:
enumName = "_AD__" + original->getName().str() + "_bb" +
std::to_string(originalBB->getDebugID()) + "__Succ__" +
indices.mangle();
break;
case AutoDiffLinearMapKind::Pullback:
enumName = "_AD__" + original->getName().str() + "_bb" +
std::to_string(originalBB->getDebugID()) + "__Pred__" +
indices.mangle();
break;
}
Mangle::ASTMangler mangler;
auto *resultIndices = IndexSubset::get(
original->getASTContext(),
original->getLoweredFunctionType()->getNumResults(), indices.source);
auto *parameterIndices = indices.parameters;
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
auto enumName = mangler.mangleAutoDiffGeneratedDeclaration(
AutoDiffGeneratedDeclarationKind::BranchingTraceEnum,
original->getName().str(), originalBB->getDebugID(), kind, config);
auto enumId = astCtx.getIdentifier(enumName);
auto loc = original->getLocation().getSourceLoc();
GenericParamList *genericParams = nullptr;
Expand Down Expand Up @@ -199,25 +196,21 @@ EnumDecl *LinearMapInfo::createBranchingTraceDecl(

StructDecl *
LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
SILAutoDiffIndices indices,
CanGenericSignature genericSig) {
assert(originalBB->getParent() == original);
auto *original = originalBB->getParent();
auto &astCtx = original->getASTContext();
auto &file = getDeclarationFileUnit();
std::string structName;
switch (kind) {
case swift::AutoDiffLinearMapKind::Differential:
structName = "_AD__" + original->getName().str() + "_bb" +
std::to_string(originalBB->getDebugID()) + "__DF__" +
indices.mangle();
break;
case swift::AutoDiffLinearMapKind::Pullback:
structName = "_AD__" + original->getName().str() + "_bb" +
std::to_string(originalBB->getDebugID()) + "__PB__" +
indices.mangle();
break;
}
// Create a linear map struct.
Mangle::ASTMangler mangler;
auto *resultIndices = IndexSubset::get(
original->getASTContext(),
original->getLoweredFunctionType()->getNumResults(), indices.source);
auto *parameterIndices = indices.parameters;
AutoDiffConfig config(parameterIndices, resultIndices, genericSig);
auto structName = mangler.mangleAutoDiffGeneratedDeclaration(
AutoDiffGeneratedDeclarationKind::LinearMapStruct,
original->getName().str(), originalBB->getDebugID(), kind, config);
auto structId = astCtx.getIdentifier(structName);
GenericParamList *genericParams = nullptr;
if (genericSig)
Expand Down Expand Up @@ -274,8 +267,7 @@ VarDecl *LinearMapInfo::addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
return linearMapDecl;
}

void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
SILAutoDiffIndices indices) {
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai) {
SmallVector<SILValue, 4> allResults;
SmallVector<unsigned, 8> activeParamIndices;
SmallVector<unsigned, 8> activeResultIndices;
Expand Down Expand Up @@ -379,7 +371,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
}

void LinearMapInfo::generateDifferentiationDataStructures(
ADContext &context, SILAutoDiffIndices indices, SILFunction *derivativeFn) {
ADContext &context, SILFunction *derivativeFn) {
auto &astCtx = original->getASTContext();
auto *loopAnalysis = context.getPassManager().getAnalysis<SILLoopAnalysis>();
auto *loopInfo = loopAnalysis->get(original);
Expand All @@ -392,8 +384,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(

// Create linear map struct for each original block.
for (auto &origBB : *original) {
auto *linearMapStruct =
createLinearMapStruct(&origBB, indices, derivativeFnGenSig);
auto *linearMapStruct = createLinearMapStruct(&origBB, derivativeFnGenSig);
linearMapStructs.insert({&origBB, linearMapStruct});
}

Expand All @@ -409,8 +400,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
break;
}
for (auto &origBB : *original) {
auto *traceEnum = createBranchingTraceDecl(&origBB, indices,
derivativeFnGenSig, loopInfo);
auto *traceEnum =
createBranchingTraceDecl(&origBB, derivativeFnGenSig, loopInfo);
branchingTraceDecls.insert({&origBB, traceEnum});
if (origBB.isEntry())
continue;
Expand All @@ -433,7 +424,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
continue;
LLVM_DEBUG(getADDebugStream()
<< "Adding linear map struct field for " << *ai);
addLinearMapToStruct(context, ai, indices);
addLinearMapToStruct(context, ai);
}
}
}
Expand Down
17 changes: 0 additions & 17 deletions lib/Sema/TypeChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,23 +410,6 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) {
#endif
}

bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) {
auto &ctx = SF.getASTContext();
// Return true if differentiable programming is explicitly enabled.
if (ctx.LangOpts.EnableExperimentalDifferentiableProgramming)
return true;
// Otherwise, return true iff the `_Differentiation` module is imported in
// the given source file.
bool importsDifferentiationModule = false;
for (auto import : namelookup::getAllImports(&SF)) {
if (import.second->getName() == ctx.Id_Differentiation) {
importsDifferentiationModule = true;
break;
}
}
return importsDifferentiationModule;
}

bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) {
auto &ctx = SF.getASTContext();
// Return true if `AdditiveArithmetic` derived conformances are explicitly
Expand Down
3 changes: 0 additions & 3 deletions lib/Sema/TypeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1520,9 +1520,6 @@ bool isMemberOperator(FuncDecl *decl, Type type);
/// Complain if @objc or dynamic is used without importing Foundation.
void diagnoseAttrsRequiringFoundation(SourceFile &SF);

/// Returns `true` iff differentiable programming is enabled.
bool isDifferentiableProgrammingEnabled(SourceFile &SF);

/// Returns `true` iff `AdditiveArithmetic` derived conformances are enabled.
bool isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF);

Expand Down