Skip to content

Commit 3cbc148

Browse files
author
Eugene Burmako
committed
Merge remote-tracking branch 'origin/tensorflow' into tensorflow-merge
2 parents ccb2260 + f0b1c0b commit 3cbc148

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+7038
-5520
lines changed

include/swift/AST/ASTContext.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ namespace swift {
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112112
class IndexSubset;
113+
// SWIFT_ENABLE_TENSORFLOW
114+
struct AutoDiffConfig;
113115
class VectorSpace;
114116
class DifferentiableAttr;
117+
// SWIFT_ENABLE_TENSORFLOW END
115118

116119
enum class KnownProtocolKind : uint8_t;
117120

@@ -711,6 +714,21 @@ class ASTContext final {
711714
unsigned previousGeneration,
712715
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods);
713716

717+
// SWIFT_ENABLE_TENSORFLOW
718+
/// Load derivative function configurations for the given
719+
/// AbstractFunctionDecl.
720+
///
721+
/// \param originalAFD The declaration whose derivative function
722+
/// configurations should be loaded.
723+
///
724+
/// \param previousGeneration The previous generation number. The AST already
725+
/// contains derivative function configurations loaded from any generation up
726+
/// to and including this one.
727+
void loadDerivativeFunctionConfigurations(
728+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
729+
llvm::SetVector<AutoDiffConfig> &results);
730+
// SWIFT_ENABLE_TENSORFLOW END
731+
714732
/// Retrieve the Clang module loader for this ASTContext.
715733
///
716734
/// If there is no Clang module loader, returns a null pointer.

include/swift/AST/Decl.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5698,6 +5698,30 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
56985698
private:
56995699
ParameterList *Params;
57005700

5701+
// SWIFT_ENABLE_TENSORFLOW
5702+
private:
5703+
/// The generation at which we last loaded derivative function configurations.
5704+
unsigned DerivativeFunctionConfigGeneration = 0;
5705+
/// Prepare to traverse the list of derivative function configurations.
5706+
void prepareDerivativeFunctionConfigurations();
5707+
5708+
/// A uniqued list of derivative function configurations.
5709+
/// - `@differentiable` and `@derivative` attribute type-checking is
5710+
/// responsible for populating derivative function configurations specified
5711+
/// in the current module.
5712+
/// - Module loading is responsible for populating derivative function
5713+
/// configurations from imported modules.
5714+
struct DerivativeFunctionConfigurationList;
5715+
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
5716+
5717+
public:
5718+
/// Get all derivative function configurations.
5719+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
5720+
5721+
/// Add the given derivative function configuration.
5722+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
5723+
// SWIFT_ENABLE_TENSORFLOW END
5724+
57015725
protected:
57025726
// If a function has a body at all, we have either a parsed body AST node or
57035727
// we have saved the end location of the unparsed body.

include/swift/AST/DiagnosticsSIL.def

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,12 @@ NOTE(autodiff_expression_not_differentiable_note,none,
467467
NOTE(autodiff_external_nondifferentiable_function,none,
468468
"cannot differentiate functions that have not been marked "
469469
"'@differentiable' and that are defined in other files", ())
470+
NOTE(autodiff_private_derivative_from_fragile,none,
471+
"differentiated functions in "
472+
"%select{'@inlinable' functions|default arguments}0 must be marked "
473+
"'@differentiable' or have a public '@derivative'"
474+
"%select{|; this is not possible with a closure, make a top-level "
475+
"function instead}1", (unsigned, bool))
470476
NOTE(autodiff_nondifferentiable_argument,none,
471477
"cannot differentiate through a non-differentiable argument; do you want "
472478
"to use 'withoutDerivative(at:)'?", ())

include/swift/AST/ModuleLoader.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class DependencyCollector;
3434

3535
namespace swift {
3636

37+
// SWIFT_ENABLE_TENSORFLOW
38+
struct AutoDiffConfig;
39+
// SWIFT_ENABLE_TENSORFLOW END
3740
class AbstractFunctionDecl;
3841
class ClangImporterOptions;
3942
class ClassDecl;
@@ -151,6 +154,25 @@ class ModuleLoader {
151154
unsigned previousGeneration,
152155
llvm::TinyPtrVector<AbstractFunctionDecl *> &methods) = 0;
153156

157+
// SWIFT_ENABLE_TENSORFLOW
158+
/// Load derivative function configurations for the given
159+
/// AbstractFunctionDecl.
160+
///
161+
/// \param originalAFD The declaration whose derivative function
162+
/// configurations should be loaded.
163+
///
164+
/// \param previousGeneration The previous generation number. The AST already
165+
/// contains derivative function configurations loaded from any generation up
166+
/// to and including this one.
167+
///
168+
/// \param results The result list of derivative function configurations.
169+
/// This list will be extended with any methods found in subsequent
170+
/// generations.
171+
virtual void loadDerivativeFunctionConfigurations(
172+
AbstractFunctionDecl *originalAFD, unsigned previousGeneration,
173+
llvm::SetVector<AutoDiffConfig> &results) {};
174+
// SWIFT_ENABLE_TENSORFLOW END
175+
154176
/// Verify all modules loaded by this loader.
155177
virtual void verifyAllModules() { }
156178
};

include/swift/AST/Types.h

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4303,12 +4303,82 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
43034303

43044304
CanSILFunctionType getWithoutDifferentiability();
43054305

4306-
/// Returns the type of the derivative function.
4306+
/// Returns the type of the derivative function for the given parameter
4307+
/// indices, result index, derivative function kind, derivative function
4308+
/// generic signature (optional), and other auxiliary parameters.
4309+
///
4310+
/// Preconditions:
4311+
/// - Parameters corresponding to parameter indices must conform to
4312+
/// `Differentiable`.
4313+
/// - The result corresponding to the result index must conform to
4314+
/// `Differentiable`.
4315+
///
4316+
/// Typing rules, given:
4317+
/// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
4318+
///
4319+
/// Terminology:
4320+
/// - The derivative of a `Differentiable`-conforming type has the
4321+
/// `TangentVector` associated type. `TangentVector` is abbreviated as `Tan`
4322+
/// below.
4323+
/// - "wrt" parameters refers to parameters indicated by the parameter
4324+
/// indices.
4325+
/// - "wrt" result refers to the result indicated by the result index.
4326+
///
4327+
/// JVP derivative type:
4328+
/// - Takes original parameters.
4329+
/// - Returns original results, followed by a differential function, which
4330+
/// takes "wrt" parameter derivatives and returns a "wrt" result derivative.
4331+
///
4332+
/// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan)
4333+
/// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~
4334+
/// original results | derivatives wrt params | derivative wrt result
4335+
///
4336+
/// VJP derivative type:
4337+
/// - Takes original parameters.
4338+
/// - Returns original results, followed by a pullback function, which
4339+
/// takes a "wrt" result derivative and returns "wrt" parameter derivatives.
4340+
///
4341+
/// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...))
4342+
/// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~
4343+
/// original results | derivative wrt result | derivatives wrt params
4344+
///
4345+
/// The JVP/VJP generic signature is a "constrained" version of the given
4346+
/// `derivativeFunctionGenericSignature` if specified. Otherwise, it is a
4347+
/// "constrained" version of the original generic signature. A "constrained"
4348+
/// generic signature requires all "wrt" parameters to conform to
4349+
/// `Differentiable`; this is important for correctness.
4350+
///
4351+
/// Other properties of the original function type are copied exactly:
4352+
/// `ExtInfo`, coroutine kind, callee convention, yields, optional error
4353+
/// result, witness method conformance, etc.
4354+
///
4355+
/// Special cases:
4356+
/// - Reabstraction thunks have special derivative type calculation. The
4357+
/// original function-typed last parameter is transformed into a
4358+
/// `@differentiable` function-typed parameter in the derivative type. This
4359+
/// is necessary for the differentiation transform to support reabstraction
4360+
/// thunk differentiation because the function argument is opaque and cannot
4361+
/// be differentiated. Instead, the argument is made `@differentiable` and
4362+
/// reabstraction thunk JVP/VJP callers are responsible for passing a
4363+
/// `@differentiable` function.
4364+
/// - TODO(TF-1036): Investigate more efficient reabstraction thunk
4365+
/// derivative approaches. The last argument can simply be a
4366+
/// corresponding derivative function, instead of a `@differentiable`
4367+
/// function - this is more direct. It may be possible to implement
4368+
/// reabstraction thunk derivatives using "reabstraction thunks for
4369+
/// the original function's derivative", avoiding extra code generation.
4370+
///
4371+
/// Caveats:
4372+
/// - We may support multiple result indices instead of a single result index
4373+
/// eventually. At the SIL level, this enables differentiating wrt multiple
4374+
/// function results. At the Swift level, this enables differentiating wrt
4375+
/// multiple tuple elements for tuple-returning functions.
43074376
CanSILFunctionType getAutoDiffDerivativeFunctionType(
43084377
IndexSubset *parameterIndices, unsigned resultIndex,
43094378
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
43104379
LookupConformanceFn lookupConformance,
4311-
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
4380+
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
4381+
bool isReabstractionThunk = false);
43124382

43134383
/// Returns the type of the transpose function.
43144384
CanSILFunctionType getAutoDiffTransposeFunctionType(

include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class DifferentiableActivityInfo {
131131
SmallVector<SmallDenseSet<SILValue>, 4> usefulValueSets;
132132

133133
/// The original function.
134-
SILFunction &getFunction();
134+
SILFunction &getFunction() const;
135135

136136
/// Returns true if the given SILValue has a tangent space.
137137
bool hasTangentSpace(SILValue value) {
@@ -206,6 +206,14 @@ class DifferentiableActivityInfo {
206206
/// Returns the activity of the given value for the given `SILAutoDiffIndices`
207207
/// (parameter indices and result index).
208208
Activity getActivity(SILValue value, const SILAutoDiffIndices &indices) const;
209+
210+
/// Prints activity information for the `indices` of the given `value`.
211+
void dump(SILValue value, const SILAutoDiffIndices &indices,
212+
llvm::raw_ostream &s = llvm::dbgs()) const;
213+
214+
/// Prints activity information for the given `indices`.
215+
void dump(SILAutoDiffIndices indices,
216+
llvm::raw_ostream &s = llvm::dbgs()) const;
209217
};
210218

211219
class DifferentiableActivityCollection {

include/swift/SILOptimizer/Utils/Differentiation/ADContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ class ADContext {
136136
}
137137

138138
/// Adds the given `differentiable_function` instruction to the worklist.
139-
void addDifferentiableFunctionInst(DifferentiableFunctionInst* dfi) {
139+
void
140+
addDifferentiableFunctionInstToWorklist(DifferentiableFunctionInst *dfi) {
140141
differentiableFunctionInsts.push_back(dfi);
141142
}
142143

0 commit comments

Comments
 (0)