Skip to content

[NFC][AutoDiff] Pass AutoDiffConfig by const ref. #37074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class ASTMangler : public Mangler {
std::string
mangleAutoDiffDerivativeFunction(const AbstractFunctionDecl *originalAFD,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config,
const AutoDiffConfig &config,
bool isVTableThunk = false);

/// Mangle the linear map (differential/pullback) for the given:
Expand All @@ -196,7 +196,7 @@ class ASTMangler : public Mangler {
/// derivative generic signature.
std::string mangleAutoDiffLinearMap(const AbstractFunctionDecl *originalAFD,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);
const AutoDiffConfig &config);

/// Mangle the linear map self parameter reordering thunk the given:
/// - Mangled original function declaration.
Expand All @@ -210,7 +210,7 @@ class ASTMangler : public Mangler {
/// Mangle a SIL differentiability witness.
std::string mangleSILDifferentiabilityWitness(StringRef originalName,
DifferentiabilityKind kind,
AutoDiffConfig config);
const AutoDiffConfig &config);

/// Mangle the AutoDiff generated declaration for the given:
/// - Generated declaration kind: linear map struct or branching trace enum.
Expand All @@ -223,7 +223,7 @@ class ASTMangler : public Mangler {
mangleAutoDiffGeneratedDeclaration(AutoDiffGeneratedDeclarationKind declKind,
StringRef origFnName, unsigned bbId,
AutoDiffLinearMapKind linearMapKind,
AutoDiffConfig config);
const AutoDiffConfig &config);

std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
GenericSignature signature,
Expand Down Expand Up @@ -453,7 +453,7 @@ class ASTMangler : public Mangler {
const AbstractFunctionDecl *afd);
void appendAutoDiffFunctionParts(StringRef op,
Demangle::AutoDiffFunctionKind kind,
AutoDiffConfig config);
const AutoDiffConfig &config);
void appendIndexSubset(IndexSubset *indexSubset);
};

Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5718,7 +5718,7 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
void addDerivativeFunctionConfiguration(const AutoDiffConfig &config);

protected:
// If a function has a body at all, we have either a parsed body AST node or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class DifferentiableActivityInfo {
IndexSubset *resultIndices) const;

/// Returns true if the given value is active for the given config.
bool isActive(SILValue value, AutoDiffConfig config) const {
bool isActive(SILValue value, const AutoDiffConfig &config) const {
return isActive(value, config.parameterIndices, config.resultIndices);
}

Expand All @@ -217,7 +217,7 @@ class DifferentiableActivityInfo {
IndexSubset *resultIndices) const;

/// Returns the activity of the given value for the given config.
Activity getActivity(SILValue value, AutoDiffConfig config) const {
Activity getActivity(SILValue value, const AutoDiffConfig &config) const {
return getActivity(value, config.parameterIndices, config.resultIndices);
}

Expand All @@ -227,7 +227,7 @@ class DifferentiableActivityInfo {
llvm::raw_ostream &s = llvm::dbgs()) const;

/// Prints activity information for the config of the given value.
void dump(SILValue value, AutoDiffConfig config,
void dump(SILValue value, const AutoDiffConfig &config,
llvm::raw_ostream &s = llvm::dbgs()) const {
return dump(value, config.parameterIndices, config.resultIndices, s);
}
Expand All @@ -238,7 +238,8 @@ class DifferentiableActivityInfo {
llvm::raw_ostream &s = llvm::dbgs()) const;

/// Prints all activity information for the given config.
void dump(AutoDiffConfig config, llvm::raw_ostream &s = llvm::dbgs()) const {
void dump(const AutoDiffConfig &config,
llvm::raw_ostream &s = llvm::dbgs()) const {
return dump(config.parameterIndices, config.resultIndices, s);
}
};
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void collectAllActualResultsInTypeOrder(
/// - The set of minimal parameter and result indices for differentiating the
/// `apply` instruction.
void collectMinimalIndicesForFunctionCall(
ApplyInst *ai, AutoDiffConfig parentConfig,
ApplyInst *ai, const AutoDiffConfig &parentConfig,
const DifferentiableActivityInfo &activityInfo,
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
SmallVectorImpl<unsigned> &resultIndices);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class LinearMapInfo {

explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
SILFunction *original, SILFunction *derivative,
AutoDiffConfig config,
const AutoDiffConfig &config,
const DifferentiableActivityInfo &activityInfo,
SILLoopInfo *loopInfo);

Expand Down
4 changes: 2 additions & 2 deletions include/swift/SILOptimizer/Differentiation/Thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ SILValue reabstractFunction(
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
AutoDiffConfig actualConfig, ADContext &adContext);
AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig,
const AutoDiffConfig &actualConfig, ADContext &adContext);

/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SILOptimizer/Differentiation/VJPCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class VJPCloner final {
SILFunction &getVJP() const;
SILFunction &getPullback() const;
SILDifferentiabilityWitness *getWitness() const;
AutoDiffConfig getConfig() const;
const AutoDiffConfig &getConfig() const;
DifferentiationInvoker getInvoker() const;
LinearMapInfo &getPullbackInfo() const;
SILLoopInfo *getLoopInfo() const;
Expand Down
6 changes: 3 additions & 3 deletions include/swift/SILOptimizer/Utils/DifferentiationMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ class DifferentiationMangler : public ASTMangler {
/// Returns the mangled name for a differentiation function of the given kind.
std::string mangleAutoDiffFunction(StringRef originalName,
Demangle::AutoDiffFunctionKind kind,
AutoDiffConfig config);
const AutoDiffConfig &config);
/// Returns the mangled name for a derivative function of the given kind.
std::string mangleDerivativeFunction(StringRef originalName,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);
const AutoDiffConfig &config);
/// Returns the mangled name for a linear map of the given kind.
std::string mangleLinearMap(StringRef originalName,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);
const AutoDiffConfig &config);
/// Returns the mangled name for a derivative function subset parameters
/// thunk.
std::string mangleDerivativeFunctionSubsetParametersThunk(
Expand Down
18 changes: 10 additions & 8 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ std::string ASTMangler::mangleObjCAsyncCompletionHandlerImpl(
std::string ASTMangler::mangleAutoDiffDerivativeFunction(
const AbstractFunctionDecl *originalAFD,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config,
const AutoDiffConfig &config,
bool isVTableThunk) {
beginManglingWithAutoDiffOriginalFunction(originalAFD);
appendAutoDiffFunctionParts(
Expand All @@ -434,7 +434,7 @@ std::string ASTMangler::mangleAutoDiffDerivativeFunction(

std::string ASTMangler::mangleAutoDiffLinearMap(
const AbstractFunctionDecl *originalAFD, AutoDiffLinearMapKind kind,
AutoDiffConfig config) {
const AutoDiffConfig &config) {
beginManglingWithAutoDiffOriginalFunction(originalAFD);
appendAutoDiffFunctionParts("TJ", getAutoDiffFunctionKind(kind), config);
return finalize();
Expand All @@ -456,7 +456,7 @@ void ASTMangler::beginManglingWithAutoDiffOriginalFunction(

void ASTMangler::appendAutoDiffFunctionParts(StringRef op,
AutoDiffFunctionKind kind,
AutoDiffConfig config) {
const AutoDiffConfig &config) {
if (auto sig = config.derivativeGenericSignature)
appendGenericSignature(sig);
auto kindCode = (char)kind;
Expand Down Expand Up @@ -486,8 +486,8 @@ void ASTMangler::appendIndexSubset(IndexSubset *indices) {
}

static NodePointer mangleSILDifferentiabilityWitnessAsNode(
StringRef originalName, DifferentiabilityKind kind, AutoDiffConfig config,
Demangler &demangler) {
StringRef originalName, DifferentiabilityKind kind,
const AutoDiffConfig &config, Demangler &demangler) {
auto *diffWitnessNode = demangler.createNode(
Node::Kind::DifferentiabilityWitness);
auto origNode = demangler.demangleSymbol(originalName);
Expand Down Expand Up @@ -518,8 +518,9 @@ static NodePointer mangleSILDifferentiabilityWitnessAsNode(
return diffWitnessNode;
}

std::string ASTMangler::mangleSILDifferentiabilityWitness(
StringRef originalName, DifferentiabilityKind kind, AutoDiffConfig config) {
std::string ASTMangler::mangleSILDifferentiabilityWitness(StringRef originalName,
DifferentiabilityKind kind,
const AutoDiffConfig &config) {
// If the original name was a mangled name, differentiability witnesses must
// be mangled as node because they contain generic signatures which may repeat
// entities in the original function name. Mangling as node will make sure the
Expand All @@ -545,7 +546,8 @@ std::string ASTMangler::mangleSILDifferentiabilityWitness(

std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
AutoDiffGeneratedDeclarationKind declKind, StringRef origFnName,
unsigned bbId, AutoDiffLinearMapKind linearMapKind, AutoDiffConfig config) {
unsigned bbId, AutoDiffLinearMapKind linearMapKind,
const AutoDiffConfig &config) {
beginManglingWithoutPrefix();

Buffer << "_AD__" << origFnName << "_bb" + std::to_string(bbId);
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7378,7 +7378,7 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
}

void AbstractFunctionDecl::addDerivativeFunctionConfiguration(
AutoDiffConfig config) {
const AutoDiffConfig &config) {
prepareDerivativeFunctionConfigurations();
DerivativeFunctionConfigs->insert(config);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
/// - The last result in the returned pullback.
SILFunction *getOrCreateCustomDerivativeThunk(
AbstractFunctionDecl *originalAFD, SILFunction *originalFn,
SILFunction *customDerivativeFn, AutoDiffConfig config,
SILFunction *customDerivativeFn, const AutoDiffConfig &config,
AutoDiffDerivativeFunctionKind kind);

/// Get or create a derivative function vtable entry thunk for the given
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3912,7 +3912,7 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(

SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
AbstractFunctionDecl *originalAFD, SILFunction *originalFn,
SILFunction *customDerivativeFn, AutoDiffConfig config,
SILFunction *customDerivativeFn, const AutoDiffConfig &config,
AutoDiffDerivativeFunctionKind kind) {
auto customDerivativeFnTy = customDerivativeFn->getLoweredFunctionType();
auto *thunkGenericEnv = customDerivativeFnTy->getSubstGenericSignature()
Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Differentiation/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ void collectAllActualResultsInTypeOrder(
}

void collectMinimalIndicesForFunctionCall(
ApplyInst *ai, AutoDiffConfig parentConfig,
ApplyInst *ai, const AutoDiffConfig &parentConfig,
const DifferentiableActivityInfo &activityInfo,
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
SmallVectorImpl<unsigned> &resultIndices) {
Expand Down Expand Up @@ -452,7 +452,7 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
IndexSubset *&minimalASTParameterIndices) {
Optional<AutoDiffConfig> minimalConfig = None;
auto configs = original->getDerivativeFunctionConfigurations();
for (auto config : configs) {
for (auto &config : configs) {
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
config.parameterIndices,
original->getInterfaceType()->castTo<AnyFunctionType>());
Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Differentiation/JVPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ class JVPCloner::Implementation final
///
/// Original: y = apply f(x0, x1, ...)
/// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...)
void emitTangentForApplyInst(ApplyInst *ai, AutoDiffConfig applyConfig,
void emitTangentForApplyInst(ApplyInst *ai, const AutoDiffConfig &applyConfig,
CanSILFunctionType originalDifferentialType) {
assert(differentialInfo.shouldDifferentiateApplySite(ai));
auto *bb = ai->getParent();
Expand Down Expand Up @@ -1393,7 +1393,7 @@ static SubstitutionMap getSubstitutionMap(SILFunction *original,
/// and JVP generic signature.
static const DifferentiableActivityInfo &
getActivityInfo(ADContext &context, SILFunction *original,
AutoDiffConfig config, SILFunction *jvp) {
const AutoDiffConfig &config, SILFunction *jvp) {
// Get activity info of the original function.
auto &passManager = context.getPassManager();
auto *activityAnalysis =
Expand Down
5 changes: 2 additions & 3 deletions lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static GenericParamList *cloneGenericParameters(ASTContext &ctx,

LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
SILFunction *original, SILFunction *derivative,
AutoDiffConfig config,
const AutoDiffConfig &config,
const DifferentiableActivityInfo &activityInfo,
SILLoopInfo *loopInfo)
: kind(kind), original(original), derivative(derivative),
Expand Down Expand Up @@ -313,8 +313,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai) {
auto *results = IndexSubset::get(original->getASTContext(), numResults,
activeResultIndices);
// Create autodiff indices for the `apply` instruction.
AutoDiffConfig
applyConfig(parameters, results);
AutoDiffConfig applyConfig(parameters, results);

// Check for non-differentiable original function type.
auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class PullbackCloner::Implementation final
}
DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
AutoDiffConfig getConfig() const { return vjpCloner.getConfig(); }
const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); }
const DifferentiableActivityInfo &getActivityInfo() const {
return vjpCloner.getActivityInfo();
}
Expand Down
4 changes: 2 additions & 2 deletions lib/SILOptimizer/Differentiation/Thunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,8 @@ getOrCreateSubsetParametersThunkForLinearMap(
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
AutoDiffConfig actualConfig, ADContext &adContext) {
AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig,
const AutoDiffConfig &actualConfig, ADContext &adContext) {
LLVM_DEBUG(getADDebugStream()
<< "Getting a subset parameters thunk for derivative function "
<< derivativeFn << " of the original function " << origFnOperand
Expand Down
6 changes: 3 additions & 3 deletions lib/SILOptimizer/Differentiation/VJPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class VJPCloner::Implementation final

ASTContext &getASTContext() const { return vjp->getASTContext(); }
SILModule &getModule() const { return vjp->getModule(); }
AutoDiffConfig getConfig() const {
const AutoDiffConfig &getConfig() const {
return witness->getConfig();
}

Expand Down Expand Up @@ -757,7 +757,7 @@ static SubstitutionMap getSubstitutionMap(SILFunction *original,
/// and VJP generic signature.
static const DifferentiableActivityInfo &
getActivityInfoHelper(ADContext &context, SILFunction *original,
AutoDiffConfig config, SILFunction *vjp) {
const AutoDiffConfig &config, SILFunction *vjp) {
// Get activity info of the original function.
auto &passManager = context.getPassManager();
auto *activityAnalysis =
Expand Down Expand Up @@ -805,7 +805,7 @@ SILFunction &VJPCloner::getPullback() const { return *impl.pullback; }
SILDifferentiabilityWitness *VJPCloner::getWitness() const {
return impl.witness;
}
AutoDiffConfig VJPCloner::getConfig() const {
const AutoDiffConfig &VJPCloner::getConfig() const {
return impl.getConfig();
}
DifferentiationInvoker VJPCloner::getInvoker() const { return impl.invoker; }
Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ static SILValue reapplyFunctionConversion(
static Optional<std::pair<SILValue, AutoDiffConfig>>
emitDerivativeFunctionReference(
DifferentiationTransformer &transformer, SILBuilder &builder,
AutoDiffConfig desiredConfig, AutoDiffDerivativeFunctionKind kind,
const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
SILValue original, DifferentiationInvoker invoker,
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
ADContext &context = transformer.getContext();
Expand Down
8 changes: 4 additions & 4 deletions lib/SILOptimizer/Utils/DifferentiationMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static NodePointer mangleGenericSignatureAsNode(GenericSignature sig,

static NodePointer mangleAutoDiffFunctionAsNode(
StringRef originalName, Demangle::AutoDiffFunctionKind kind,
AutoDiffConfig config, Demangler &demangler) {
const AutoDiffConfig &config, Demangler &demangler) {
assert(isMangledName(originalName));
auto demangledOrig = demangler.demangleSymbol(originalName);
assert(demangledOrig && "Should only be called when the original "
Expand Down Expand Up @@ -75,7 +75,7 @@ static NodePointer mangleAutoDiffFunctionAsNode(

std::string DifferentiationMangler::mangleAutoDiffFunction(
StringRef originalName, Demangle::AutoDiffFunctionKind kind,
AutoDiffConfig config) {
const AutoDiffConfig &config) {
// If the original function is mangled, mangle the tree.
if (isMangledName(originalName)) {
Demangler demangler;
Expand All @@ -94,15 +94,15 @@ std::string DifferentiationMangler::mangleAutoDiffFunction(
// Returns the mangled name for a derivative function of the given kind.
std::string DifferentiationMangler::mangleDerivativeFunction(
StringRef originalName, AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config) {
const AutoDiffConfig &config) {
return mangleAutoDiffFunction(
originalName, getAutoDiffFunctionKind(kind), config);
}

// Returns the mangled name for a derivative function of the given kind.
std::string DifferentiationMangler::mangleLinearMap(
StringRef originalName, AutoDiffLinearMapKind kind,
AutoDiffConfig config) {
const AutoDiffConfig &config) {
return mangleAutoDiffFunction(
originalName, getAutoDiffFunctionKind(kind), config);
}
Expand Down
Loading