Skip to content

[AudoDiff] NFC: Replace 'SILAutoDiffIndices' with 'AutoDiffConfig'. #35079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 30 additions & 51 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,80 +173,59 @@ enum class AutoDiffGeneratedDeclarationKind : uint8_t {
BranchingTraceEnum
};

/// SIL-level automatic differentiation indices. Consists of:
/// - The differentiability parameter indices.
/// - The differentiability result indices.
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
// `AutoDiffConfig` additionally stores a derivative generic signature.
struct SILAutoDiffIndices {
/// The indices of independent parameters to differentiate with respect to.
IndexSubset *parameters;
/// The indices of dependent results to differentiate from.
IndexSubset *results;

/*implicit*/ SILAutoDiffIndices(IndexSubset *parameters, IndexSubset *results)
: parameters(parameters), results(results) {
assert(parameters && "Parameter indices must be non-null");
assert(results && "Result indices must be non-null");
}

bool operator==(const SILAutoDiffIndices &other) const;

bool operator!=(const SILAutoDiffIndices &other) const {
return !(*this == other);
};
/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
struct AutoDiffConfig {
IndexSubset *parameterIndices;
IndexSubset *resultIndices;
GenericSignature derivativeGenericSignature;

/*implicit*/ AutoDiffConfig(
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenericSignature = GenericSignature())
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

/// Returns true if `parameterIndex` is a differentiability parameter index.
bool isWrtParameter(unsigned parameterIndex) const {
return parameterIndex < parameters->getCapacity() &&
parameters->contains(parameterIndex);
return parameterIndex < parameterIndices->getCapacity() &&
parameterIndices->contains(parameterIndex);
}

void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
/// Returns true if `resultIndex` is a differentiability result index.
bool isWrtResult(unsigned resultIndex) const {
return resultIndex < resultIndices->getCapacity() &&
resultIndices->contains(resultIndex);
}

AutoDiffConfig withGenericSignature(GenericSignature signature) const {
return AutoDiffConfig(parameterIndices, resultIndices, signature);
}

// TODO(SR-13506): Use principled mangling for AD-generated symbols.
std::string mangle() const {
std::string result = "src_";
interleave(
results->getIndices(),
resultIndices->getIndices(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
result += "_wrt_";
llvm::interleave(
parameters->getIndices(),
parameterIndices->getIndices(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
return result;
}
};

/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
struct AutoDiffConfig {
IndexSubset *parameterIndices;
IndexSubset *resultIndices;
GenericSignature derivativeGenericSignature;

/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
IndexSubset *resultIndices,
GenericSignature derivativeGenericSignature)
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-913): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;

void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
const SILAutoDiffIndices &indices) {
indices.print(s);
const AutoDiffConfig &config) {
config.print(s);
return s;
}

Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/IndexSubset.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class IndexSubset : public llvm::FoldingSetNode {
/// Returns the number of bit words used to store the index subset.
// Note: Use `getCapacity()` to get the total index subset capacity.
// This is public only for unit testing
// (in unittests/AST/SILAutoDiffIndices.cpp).
// (in unittests/AST/IndexSubsetTests.cpp).
unsigned getNumBitWords() const {
return numBitWords;
}
Expand Down
5 changes: 0 additions & 5 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,6 @@ class SILDifferentiabilityWitness
bool isSerialized() const { return IsSerialized; }
const DeclAttribute *getAttribute() const { return Attribute; }

/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-893): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;

/// Verify that the differentiability witness is well-formed.
void verify(const SILModule &module) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,47 @@ class DifferentiableActivityInfo {
/// (dependent variable) indices.
bool isUseful(SILValue value, IndexSubset *dependentVariableIndices) const;

/// Returns true if the given value is active for the given
/// `SILAutoDiffIndices` (parameter indices and result index).
bool isActive(SILValue value, const SILAutoDiffIndices &indices) const;
/// Returns true if the given value is active for the given parameter indices
/// and result indices.
bool isActive(SILValue value,
IndexSubset *parameterIndices,
IndexSubset *resultIndices) const;

/// Returns true if the given value is active for the given config.
bool isActive(SILValue value, AutoDiffConfig config) const {
return isActive(value, config.parameterIndices, config.resultIndices);
}

/// Returns the activity of the given value for the given config.
Activity getActivity(SILValue value,
IndexSubset *parameterIndices,
IndexSubset *resultIndices) const;

/// Returns the activity of the given value for the given `SILAutoDiffIndices`
/// (parameter indices and result index).
Activity getActivity(SILValue value, const SILAutoDiffIndices &indices) const;
/// Returns the activity of the given value for the given config.
Activity getActivity(SILValue value, AutoDiffConfig config) const {
return getActivity(value, config.parameterIndices, config.resultIndices);
}

/// Prints activity information for the `indices` of the given `value`.
void dump(SILValue value, const SILAutoDiffIndices &indices,
/// Prints activity information for the config of the given value.
void dump(SILValue value,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
llvm::raw_ostream &s = llvm::dbgs()) const;

/// Prints activity information for the given `indices`.
void dump(SILAutoDiffIndices indices,
/// Prints activity information for the config of the given value.
void dump(SILValue value, AutoDiffConfig config,
llvm::raw_ostream &s = llvm::dbgs()) const {
return dump(value, config.parameterIndices, config.resultIndices, s);
}

/// Prints all activity information for the given parameter indices and result
/// indices.
void dump(IndexSubset *parameterIndices, IndexSubset *resultIndices,
llvm::raw_ostream &s = llvm::dbgs()) const;

/// Prints all activity information for the given config.
void dump(AutoDiffConfig config, llvm::raw_ostream &s = llvm::dbgs()) const {
return dump(config.parameterIndices, config.resultIndices, s);
}
};

class DifferentiableActivityCollection {
Expand Down
7 changes: 5 additions & 2 deletions include/swift/SILOptimizer/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ namespace autodiff {

/// Stores `apply` instruction information calculated by VJP generation.
struct NestedApplyInfo {
/// The differentiation indices that are used to differentiate this `apply`
/// The differentiation config that is used to differentiate this `apply`
/// instruction.
SILAutoDiffIndices indices;
AutoDiffConfig config;
/// The original pullback type before reabstraction. `None` if the pullback
/// type is not reabstracted.
Optional<CanSILFunctionType> originalPullbackType;
Expand Down Expand Up @@ -120,6 +120,9 @@ class ADContext {
/// Construct an ADContext for the given module.
explicit ADContext(SILModuleTransform &transform);

// No copying.
ADContext(const ADContext &) = delete;

//--------------------------------------------------------------------------//
// General utilities
//--------------------------------------------------------------------------//
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SILOptimizer/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ void collectAllActualResultsInTypeOrder(
/// - The set of minimal parameter and result indices for differentiating the
/// `apply` instruction.
void collectMinimalIndicesForFunctionCall(
ApplyInst *ai, SILAutoDiffIndices parentIndices,
ApplyInst *ai, AutoDiffConfig parentConfig,
const DifferentiableActivityInfo &activityInfo,
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
SmallVectorImpl<unsigned> &resultIndices);
Expand Down
4 changes: 2 additions & 2 deletions include/swift/SILOptimizer/Differentiation/LinearMapInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class LinearMapInfo {
SILLoopInfo *loopInfo;

/// Differentiation indices of the function.
const SILAutoDiffIndices indices;
const AutoDiffConfig config;

/// Mapping from original basic blocks to linear map structs.
llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs;
Expand Down Expand Up @@ -149,7 +149,7 @@ class LinearMapInfo {

explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
SILFunction *original, SILFunction *derivative,
SILAutoDiffIndices indices,
AutoDiffConfig config,
const DifferentiableActivityInfo &activityInfo,
SILLoopInfo *loopInfo);

Expand Down
6 changes: 3 additions & 3 deletions include/swift/SILOptimizer/Differentiation/Thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ SILValue reabstractFunction(
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
SILAutoDiffIndices actualIndices);
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
AutoDiffConfig actualConfig);

/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
Expand All @@ -119,7 +119,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
SILOptFunctionBuilder &fb, SILFunction *assocFn,
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig);

} // end namespace autodiff

Expand Down
2 changes: 1 addition & 1 deletion include/swift/SILOptimizer/Differentiation/VJPCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class VJPCloner final {
SILFunction &getVJP() const;
SILFunction &getPullback() const;
SILDifferentiabilityWitness *getWitness() const;
const SILAutoDiffIndices getIndices() const;
AutoDiffConfig getConfig() const;
DifferentiationInvoker getInvoker() const;
LinearMapInfo &getPullbackInfo() const;
SILLoopInfo *getLoopInfo() const;
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
Buffer << "_vjp_";
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
Buffer << config.mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
Expand All @@ -445,7 +445,7 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
Buffer << "_pullback_";
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
Buffer << config.mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
Expand Down Expand Up @@ -484,7 +484,7 @@ std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
}
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
Buffer << config.mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
Expand Down
20 changes: 0 additions & 20 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,6 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
llvm_unreachable("invalid derivative kind");
}

void SILAutoDiffIndices::print(llvm::raw_ostream &s) const {
s << "(parameters=(";
interleave(
parameters->getIndices(), [&s](unsigned p) { s << p; },
[&s] { s << ' '; });
s << ") results=(";
interleave(
results->getIndices(), [&s](unsigned p) { s << p; }, [&s] { s << ' '; });
s << "))";
}

void SILAutoDiffIndices::dump() const {
print(llvm::errs());
llvm::errs() << '\n';
}

SILAutoDiffIndices AutoDiffConfig::getSILAutoDiffIndices() const {
return SILAutoDiffIndices(parameterIndices, resultIndices);
}

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
Expand Down
4 changes: 0 additions & 4 deletions lib/SIL/IR/SILDifferentiabilityWitness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,3 @@ void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp,
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
return std::make_pair(getOriginalFunction()->getName(), getConfig());
}

SILAutoDiffIndices SILDifferentiabilityWitness::getSILAutoDiffIndices() const {
return getConfig().getSILAutoDiffIndices();
}
38 changes: 22 additions & 16 deletions lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,26 +509,27 @@ bool DifferentiableActivityInfo::isUseful(
}

bool DifferentiableActivityInfo::isActive(
SILValue value, const SILAutoDiffIndices &indices) const {
return isVaried(value, indices.parameters) &&
isUseful(value, indices.results);
SILValue value, IndexSubset *parameterIndices,
IndexSubset *resultIndices) const {
return isVaried(value, parameterIndices) && isUseful(value, resultIndices);
}

Activity DifferentiableActivityInfo::getActivity(
SILValue value, const SILAutoDiffIndices &indices) const {
SILValue value, IndexSubset *parameterIndices,
IndexSubset *resultIndices) const {
Activity activity;
if (isVaried(value, indices.parameters))
if (isVaried(value, parameterIndices))
activity |= ActivityFlags::Varied;
if (isUseful(value, indices.results))
if (isUseful(value, resultIndices))
activity |= ActivityFlags::Useful;
return activity;
}

void DifferentiableActivityInfo::dump(SILValue value,
const SILAutoDiffIndices &indices,
llvm::raw_ostream &s) const {
void DifferentiableActivityInfo::dump(
SILValue value, IndexSubset *parameterIndices, IndexSubset *resultIndices,
llvm::raw_ostream &s) const {
s << '[';
auto activity = getActivity(value, indices);
auto activity = getActivity(value, parameterIndices, resultIndices);
switch (activity.toRaw()) {
case 0:
s << "NONE";
Expand All @@ -546,19 +547,24 @@ void DifferentiableActivityInfo::dump(SILValue value,
s << "] " << value;
}

void DifferentiableActivityInfo::dump(SILAutoDiffIndices indices,
llvm::raw_ostream &s) const {
void DifferentiableActivityInfo::dump(
IndexSubset *parameterIndices, IndexSubset *resultIndices,
llvm::raw_ostream &s) const {
SILFunction &fn = getFunction();
s << "Activity info for " << fn.getName() << " at " << indices << '\n';
s << "Activity info for " << fn.getName() << " at parameter indices (";
llvm::interleaveComma(parameterIndices->getIndices(), s);
s << ") and result indices (";
llvm::interleaveComma(resultIndices->getIndices(), s);
s << "):\n";
for (auto &bb : fn) {
s << "bb" << bb.getDebugID() << ":\n";
for (auto *arg : bb.getArguments())
dump(arg, indices, s);
dump(arg, parameterIndices, resultIndices, s);
for (auto &inst : bb)
for (auto res : inst.getResults())
dump(res, indices, s);
dump(res, parameterIndices, resultIndices, s);
if (std::next(bb.getIterator()) != fn.end())
s << '\n';
}
s << "End activity info for " << fn.getName() << " at " << indices << "\n\n";
s << "End activity info for " << fn.getName() << '\n';
}
Loading