-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff upstream] Add differentiable_function
canonicalization.
#30818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
#include "swift/SIL/SILDifferentiabilityWitness.h" | ||
#include "swift/SIL/SILFunction.h" | ||
#include "swift/SIL/SILModule.h" | ||
#include "swift/SIL/TypeSubstCloner.h" | ||
|
||
namespace swift { | ||
|
||
|
@@ -33,6 +34,12 @@ namespace autodiff { | |
/// This is being used to print short debug messages within the AD pass. | ||
raw_ostream &getADDebugStream(); | ||
|
||
/// Given a function call site, gathers all of its actual results (both direct | ||
/// and indirect) in an order defined by its result type. | ||
void collectAllActualResultsInTypeOrder( | ||
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults, | ||
SmallVectorImpl<SILValue> &results); | ||
|
||
/// Returns the underlying instruction for the given SILValue, if it exists, | ||
/// peering through function conversion instructions. | ||
template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) { | ||
|
@@ -51,6 +58,70 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) { | |
return nullptr; | ||
} | ||
|
||
/// Given a range of elements, joins these into a single value. If there's | ||
/// exactly one element, returns that element. Otherwise, creates a tuple using | ||
/// a `tuple` instruction. | ||
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder, | ||
SILLocation loc); | ||
|
||
/// Given a value, extracts all elements to `results` from this value if it has | ||
/// a tuple type. Otherwise, add this value directly to `results`. | ||
void extractAllElements(SILValue value, SILBuilder &builder, | ||
SmallVectorImpl<SILValue> &results); | ||
|
||
/// Emit a zero value into the given buffer access by calling | ||
/// `AdditiveArithmetic.zero`. The given type must conform to | ||
/// `AdditiveArithmetic`. | ||
void emitZeroIntoBuffer(SILBuilder &builder, CanType type, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please fix these APIs to not take a SILBuilder. Instead, it should take an insertion point and a SILBuilderContext. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change seems important for correctness. It can probably help clean up this code: I started looking into this but ran into some difficulties. The user above wants to insert at the end of a basic block (
|
||
SILValue bufferAccess, SILLocation loc); | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Utilities for looking up derivatives of functions | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Returns a differentiability witness (definition or declaration) exactly | ||
/// matching the specified indices. If none are found in the given `module`, | ||
/// returns `nullptr`. | ||
/// | ||
/// \param parameterIndices must be lowered to SIL. | ||
/// \param resultIndices must be lowered to SIL. | ||
SILDifferentiabilityWitness * | ||
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, | ||
IndexSubset *parameterIndices, | ||
IndexSubset *resultIndices); | ||
|
||
/// Finds the derivative configuration (from `@differentiable` and | ||
/// `@derivative` attributes) for `original` whose parameter indices are a | ||
/// minimal superset of the specified AST parameter indices. Returns `None` if | ||
/// no such configuration is found. | ||
/// | ||
/// \param parameterIndices must be lowered to SIL. | ||
/// \param minimalASTParameterIndices is an output parameter that is set to the | ||
/// AST indices of the minimal configuration, or to `nullptr` if no such | ||
/// configuration exists. | ||
Optional<AutoDiffConfig> | ||
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, | ||
IndexSubset *parameterIndices, | ||
IndexSubset *&minimalASTParameterIndices); | ||
|
||
/// Returns a differentiability witness for `original` whose parameter indices | ||
/// are a minimal superset of the specified parameter indices and whose result | ||
/// indices match the given result indices, out of all | ||
/// differentiability witnesses that come from AST "@differentiable" or | ||
/// "@differentiating" attributes. | ||
/// | ||
/// This function never creates new differentiability witness definitions. | ||
/// However, this function may create new differentiability witness declarations | ||
/// referring to definitions in other modules when these witnesses have not yet | ||
/// been declared in the current module. | ||
/// | ||
/// \param module is the SILModule in which to get or create the witnesses. | ||
/// \param parameterIndices must be lowered to SIL. | ||
/// \param resultIndices must be lowered to SIL. | ||
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( | ||
SILModule &module, SILFunction *original, IndexSubset *parameterIndices, | ||
IndexSubset *resultIndices); | ||
|
||
} // end namespace autodiff | ||
|
||
/// Creates arguments in the entry block based on the function type. | ||
|
@@ -85,6 +156,34 @@ inline void createEntryArguments(SILFunction *f) { | |
} | ||
} | ||
|
||
/// Cloner that remaps types using the target function's generic environment. | ||
class BasicTypeSubstCloner final | ||
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> { | ||
|
||
static SubstitutionMap getSubstitutionMap(SILFunction *target) { | ||
if (auto *targetGenEnv = target->getGenericEnvironment()) | ||
return targetGenEnv->getForwardingSubstitutionMap(); | ||
return SubstitutionMap(); | ||
} | ||
|
||
public: | ||
explicit BasicTypeSubstCloner(SILFunction *original, SILFunction *target) | ||
: TypeSubstCloner(*target, *original, getSubstitutionMap(target)) {} | ||
|
||
void postProcess(SILInstruction *orig, SILInstruction *cloned) { | ||
SILClonerWithScopes::postProcess(orig, cloned); | ||
} | ||
|
||
void cloneFunction() { | ||
auto &newFunction = Builder.getFunction(); | ||
auto *entry = newFunction.createBasicBlock(); | ||
createEntryArguments(&newFunction); | ||
SmallVector<SILValue, 8> entryArguments(newFunction.getArguments().begin(), | ||
newFunction.getArguments().end()); | ||
cloneFunctionBody(&Original, entry, entryArguments); | ||
} | ||
}; | ||
|
||
} // end namespace swift | ||
|
||
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_COMMON_H |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
//===--- Thunk.h - Automatic differentiation thunks -----------*- C++ -*---===// | ||
// | ||
// This source file is part of the Swift.org open source project | ||
// | ||
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors | ||
// Licensed under Apache License v2.0 with Runtime Library Exception | ||
// | ||
// See https://swift.org/LICENSE.txt for license information | ||
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Automatic differentiation thunk generation utilities. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H | ||
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H | ||
|
||
#include "swift/AST/AutoDiff.h" | ||
#include "swift/Basic/LLVM.h" | ||
#include "swift/SIL/SILBuilder.h" | ||
|
||
namespace swift { | ||
|
||
class SILOptFunctionBuilder; | ||
class SILModule; | ||
class SILLocation; | ||
class SILValue; | ||
class OpenedArchetypeType; | ||
class GenericEnvironment; | ||
class SubstitutionMap; | ||
class ArchetypeType; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Helpers | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace autodiff { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Thunk helpers | ||
//===----------------------------------------------------------------------===// | ||
// These helpers are copied/adapted from SILGen. They should be refactored and | ||
// moved to a shared location. | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// The thunk kinds used in the differentiation transform. | ||
enum class DifferentiationThunkKind { | ||
/// A reabstraction thunk. | ||
/// | ||
/// Reabstraction thunks transform a function-typed value to another one with | ||
/// different parameter/result abstraction patterns. This is identical to the | ||
/// thunks generated by SILGen. | ||
Reabstraction, | ||
|
||
/// An index subset thunk. | ||
/// | ||
/// An index subset thunk is used transform JVP/VJPs into a version that is | ||
/// "wrt" fewer differentiation parameters. | ||
/// - Differentials of thunked JVPs use zero for non-requested differentiation | ||
/// parameters. | ||
/// - Pullbacks of thunked VJPs discard results for non-requested | ||
/// differentiation parameters. | ||
IndexSubset | ||
}; | ||
|
||
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig, | ||
OpenedArchetypeType *openedExistential, | ||
GenericEnvironment *&genericEnv, | ||
SubstitutionMap &contextSubs, | ||
SubstitutionMap &interfaceSubs, | ||
ArchetypeType *&newArchetype); | ||
|
||
/// Build the type of a function transformation thunk. | ||
CanSILFunctionType buildThunkType(SILFunction *fn, | ||
CanSILFunctionType &sourceType, | ||
CanSILFunctionType &expectedType, | ||
GenericEnvironment *&genericEnv, | ||
SubstitutionMap &interfaceSubs, | ||
bool withoutActuallyEscaping, | ||
DifferentiationThunkKind thunkKind); | ||
|
||
/// Get or create a derivative function parameter index subset thunk from | ||
/// `actualIndices` to `desiredIndices` for the given associated function | ||
/// value and original function operand. Returns a pair of the parameter | ||
/// index subset thunk and its interface substitution map (used to partially | ||
/// apply the thunk). | ||
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear | ||
/// map returned by the derivative function. | ||
std::pair<SILFunction *, SubstitutionMap> | ||
getOrCreateSubsetParametersThunkForDerivativeFunction( | ||
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn, | ||
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, | ||
SILAutoDiffIndices actualIndices); | ||
|
||
/// Get or create a derivative function parameter index subset thunk from | ||
/// `actualIndices` to `desiredIndices` for the given associated function | ||
/// value and original function operand. Returns a pair of the parameter | ||
/// index subset thunk and its interface substitution map (used to partially | ||
/// apply the thunk). | ||
std::pair<SILFunction *, SubstitutionMap> | ||
getOrCreateSubsetParametersThunkForLinearMap( | ||
SILOptFunctionBuilder &fb, SILFunction *assocFn, | ||
CanSILFunctionType origFnType, CanSILFunctionType linearMapType, | ||
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, | ||
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); | ||
|
||
} // end namespace autodiff | ||
|
||
} // end namespace swift | ||
|
||
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_THUNK_H |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a number of problems with this. I am pretty sure we already have something like this.