Skip to content

Commit bb0aa1c

Browse files
authored
Merge pull request #30821 from dan-zheng/differentiation-transform
[AutoDiff upstream] Add reverse-mode automatic differentiation.
2 parents bd160ac + 52374bf commit bb0aa1c

26 files changed

+5994
-10
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,25 @@ NOTE(autodiff_member_subset_indices_not_differentiable,none,
510510
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
511511
"cannot convert a direct method reference to a '@differentiable' "
512512
"function; use an explicit closure instead", ())
513+
NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
514+
"cannot differentiate through multiple results", ())
515+
// TODO(TF-1149): Remove this diagnostic.
516+
NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
517+
"cannot yet differentiate value whose type %0 has a compile-time known "
518+
"size, but whose 'TangentVector' contains stored properties of unknown "
519+
"size; consider modifying %1 to use fewer generic parameters in stored "
520+
"properties", (Type, Type))
521+
NOTE(autodiff_enums_unsupported,none,
522+
"differentiating enum values is not yet supported", ())
523+
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
524+
"property cannot be differentiated because '%0.TangentVector' does not "
525+
"have a member named '%1'", (StringRef, StringRef))
526+
NOTE(autodiff_coroutines_not_supported,none,
527+
"differentiation of coroutine calls is not yet supported", ())
528+
NOTE(autodiff_cannot_differentiate_writes_to_global_variables,none,
529+
"cannot differentiate writes to global variables", ())
530+
NOTE(autodiff_cannot_differentiate_writes_to_mutable_captures,none,
531+
"cannot differentiate writes to mutable captures", ())
513532

514533
ERROR(non_physical_addressof,none,
515534
"addressof only works with purely physical lvalues; "

include/swift/AST/SourceFile.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ class SourceFile final : public FileUnit {
394394
void cacheVisibleDecls(SmallVectorImpl<ValueDecl *> &&globals) const;
395395
const SmallVectorImpl<ValueDecl *> &getCachedVisibleDecls() const;
396396

397+
void addVisibleDecl(ValueDecl *decl);
398+
397399
virtual void lookupValue(DeclName name, NLKind lookupKind,
398400
SmallVectorImpl<ValueDecl*> &result) const override;
399401

include/swift/SIL/ApplySite.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,17 @@ class FullApplySite : public ApplySite {
510510
return getArguments().slice(getNumIndirectSILResults());
511511
}
512512

513+
/// Returns the `inout` argument range of the underlying apply instruction.
///
/// Dispatches on the apply-site kind and forwards to the
/// `getInoutArguments()` of the concrete instruction (`ApplyInst`,
/// `TryApplyInst` or `BeginApplyInst`).
///
/// NOTE(review): there is no statement after the switch, so this relies on
/// the three cases covering every `FullApplySiteKind` value - confirm, and
/// consider an unreachable marker after the switch for compilers that warn
/// about a missing return.
InoutArgumentRange getInoutArguments() const {
514+
switch (getKind()) {
515+
case FullApplySiteKind::ApplyInst:
516+
return cast<ApplyInst>(getInstruction())->getInoutArguments();
517+
case FullApplySiteKind::TryApplyInst:
518+
return cast<TryApplyInst>(getInstruction())->getInoutArguments();
519+
case FullApplySiteKind::BeginApplyInst:
520+
return cast<BeginApplyInst>(getInstruction())->getInoutArguments();
521+
}
522+
}
523+
513524
/// Returns true if \p op is the callee operand of this apply site
514525
/// and not an argument operand.
515526
bool isCalleeOperand(const Operand &op) const {

include/swift/SIL/SILCloner.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ class SILCloner : protected SILInstructionVisitor<ImplClass> {
4747
TypeSubstitutionMap OpenedExistentialSubs;
4848
SILOpenedArchetypesTracker OpenedArchetypesTracker;
4949

50-
private:
51-
/// MARK: Private state hidden from CRTP extensions.
52-
5350
// The old-to-new value map.
5451
llvm::DenseMap<SILValue, SILValue> ValueMap;
5552

5653
/// The old-to-new block map. Some entries may be premapped with original
5754
/// blocks.
5855
llvm::DenseMap<SILBasicBlock*, SILBasicBlock*> BBMap;
5956

57+
private:
58+
/// MARK: Private state hidden from CRTP extensions.
59+
6060
// The original blocks in DFS preorder. All blocks in this list are mapped.
6161
// After cloning, this represents the entire cloned CFG.
6262
//

include/swift/SILOptimizer/Analysis/Analysis.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ ANALYSIS(Caller)
3131
ANALYSIS(ClassHierarchy)
3232
ANALYSIS(ClosureScope)
3333
ANALYSIS(Destructor)
34+
ANALYSIS(DifferentiableActivity)
3435
ANALYSIS(Dominance)
3536
ANALYSIS(EpilogueARC)
3637
ANALYSIS(Escape)
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//===--- DifferentiableActivityAnalysis.h ---------------------*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file implements activity analysis: a dataflow analysis used for
14+
// automatic differentiation.
15+
//
16+
// In many real situations, the end-users of AD need only the derivatives of
17+
// some selected outputs of `P` with respect to some selected inputs of `P`.
18+
// Whatever the differentiation mode (tangent, reverse,...), these restrictions
19+
// allow the AD tool to produce a much more efficient differentiated program.
20+
// Essentially, fixing some inputs and neglecting some outputs allows AD to
21+
// just forget about several intermediate differentiated variables.
22+
//
23+
// Activity analysis is the specific analysis that detects these situations,
24+
// therefore allowing for a better differentiated code. Activity analysis is
25+
// present in all transformation-based AD tools.
26+
//
27+
// To begin with, the end-user specifies that only some output variables (the
28+
// “dependent”) must be differentiated with respect to only some input
29+
// variables (the “independent”). We say that variable `y` depends on `x` when
30+
// the derivative of `y` with respect to `x` is not trivially null. We say that
31+
// a variable is “varied” if it depends on at least one independent. Conversely
32+
// we say that a variable is “useful” if at least one dependent depends on it.
33+
// Finally, we say that a variable is “active” if it is at the same time varied
34+
// and useful. In the special case of the tangent mode, it is easy to check
35+
// that when variable `v` is not varied at some place in the program, then its
36+
// derivative `v̇` at this place is certainly null. Conversely when variable `v`
37+
// is not useful, then whatever the value of `v̇`, this value does not matter
38+
// for the final result. Symmetric reasoning applies for the reverse mode of
39+
// AD: observing that differentiated variables go upstream, we see that a
40+
// useless variable has a null derivative, in other words the partial
41+
// derivative of the output with respect to this variable is null. Conversely
42+
// when variable `v` is not varied, then whatever the value of `v̄`, this value
43+
// does not matter for the final result.
44+
//
45+
// Reference:
46+
// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2007.
47+
48+
#ifndef SWIFT_SILOPTIMIZER_ANALYSIS_DIFFERENTIABLEACTIVITYANALYSIS_H_
49+
#define SWIFT_SILOPTIMIZER_ANALYSIS_DIFFERENTIABLEACTIVITYANALYSIS_H_
50+
51+
#include "swift/AST/GenericEnvironment.h"
52+
#include "swift/AST/GenericSignatureBuilder.h"
53+
#include "swift/SIL/SILFunction.h"
54+
#include "swift/SIL/SILModule.h"
55+
#include "swift/SIL/SILValue.h"
56+
#include "swift/SILOptimizer/Analysis/Analysis.h"
57+
#include "llvm/ADT/DenseMap.h"
58+
#include "llvm/ADT/DenseSet.h"
59+
60+
using llvm::SmallDenseMap;
61+
using llvm::SmallDenseSet;
62+
63+
namespace swift {
64+
65+
class DominanceAnalysis;
66+
class PostDominanceAnalysis;
67+
class DominanceInfo;
68+
class PostDominanceInfo;
69+
class SILFunction;
70+
71+
class DifferentiableActivityCollection;
72+
/// Function-level SIL analysis whose per-function result is a
/// `DifferentiableActivityCollection`.
class DifferentiableActivityAnalysis
73+
: public FunctionAnalysisBase<DifferentiableActivityCollection> {
74+
private:
75+
/// Dominance analysis; null by default.
/// NOTE(review): presumably set in `initialize` (defined out of line) -
/// confirm.
DominanceAnalysis *dominanceAnalysis = nullptr;
76+
/// Post-dominance analysis; null by default.
/// NOTE(review): presumably set in `initialize` (defined out of line) -
/// confirm.
PostDominanceAnalysis *postDominanceAnalysis = nullptr;
77+
78+
public:
79+
/// Creates the analysis with kind `SILAnalysisKind::DifferentiableActivity`.
explicit DifferentiableActivityAnalysis()
80+
: FunctionAnalysisBase(SILAnalysisKind::DifferentiableActivity) {}
81+
82+
/// Returns true if the given analysis has kind
/// `SILAnalysisKind::DifferentiableActivity`.
static bool classof(const SILAnalysis *s) {
83+
return s->getKind() == SILAnalysisKind::DifferentiableActivity;
84+
}
85+
86+
/// Returns true if the invalidation kind `k` shares any bit with
/// `InvalidationKind::Everything`.
virtual bool shouldInvalidate(SILAnalysis::InvalidationKind k) override {
87+
return k & InvalidationKind::Everything;
88+
}
89+
90+
/// Creates the activity collection for the function `f`. Defined out of
/// line.
virtual std::unique_ptr<DifferentiableActivityCollection>
91+
newFunctionAnalysis(SILFunction *f) override;
92+
93+
/// Initializes this analysis with the pass manager `pm`. Defined out of
/// line.
virtual void initialize(SILPassManager *pm) override;
94+
};
95+
96+
/// Represents the differentiation activity associated with a SIL value.
97+
enum class ActivityFlags : unsigned {
98+
/// The value depends on a function parameter.
99+
Varied = 1 << 1,
100+
/// The value contributes to a result.
101+
Useful = 1 << 2,
102+
/// The value is both varied and useful.
103+
Active = Varied | Useful,
104+
};
105+
106+
using Activity = OptionSet<ActivityFlags>;
107+
108+
/// Result of activity analysis on a function. Accepts queries for whether a
109+
/// value is "varied", "useful" or "active" against certain differentiation
110+
/// indices.
111+
class DifferentiableActivityInfo {
112+
private:
113+
DifferentiableActivityCollection &parent;
114+
115+
/// The derivative generic signature.
116+
GenericSignature derivativeGenericSignature;
117+
118+
/// Input values, i.e. parameters (both direct and indirect).
119+
SmallVector<SILValue, 4> inputValues;
120+
/// Output values, i.e. individual values (not the final tuple) being returned
121+
/// by the `return` instruction.
122+
SmallVector<SILValue, 4> outputValues;
123+
124+
/// The set of varied variables, indexed by the corresponding independent
125+
/// value (input) index.
126+
SmallVector<SmallDenseSet<SILValue>, 4> variedValueSets;
127+
/// The set of useful variables, indexed by the corresponding dependent value
128+
/// (output) index.
129+
SmallVector<SmallDenseSet<SILValue>, 4> usefulValueSets;
130+
131+
/// The original function.
132+
SILFunction &getFunction() const;
133+
134+
/// Returns true if the given SILValue has a tangent space.
135+
bool hasTangentSpace(SILValue value) {
136+
auto type = value->getType().getASTType();
137+
// Remap archetypes in the derivative generic signature, if it exists.
138+
if (derivativeGenericSignature && type->hasArchetype()) {
139+
type = derivativeGenericSignature->getCanonicalTypeInContext(
140+
type->mapTypeOutOfContext());
141+
}
142+
// Look up conformance in the current module.
143+
auto lookupConformance =
144+
LookUpConformanceInModule(getFunction().getModule().getSwiftModule());
145+
return type->getAutoDiffTangentSpace(lookupConformance).hasValue();
146+
}
147+
148+
/// Perform analysis and populate variedness and usefulness sets.
149+
void analyze(DominanceInfo *di, PostDominanceInfo *pdi);
150+
151+
/// Marks the given value as varied and propagates variedness to users.
152+
void setVariedAndPropagateToUsers(SILValue value,
153+
unsigned independentVariableIndex);
154+
/// Propagates variedness from the given operand to its user's results.
155+
void propagateVaried(Operand *operand, unsigned independentVariableIndex);
156+
/// Marks the given value as varied and recursively propagates variedness
157+
/// inwards (to operands) through projections. Skips `@noDerivative` field
158+
/// projections.
159+
void
160+
propagateVariedInwardsThroughProjections(SILValue value,
161+
unsigned independentVariableIndex);
162+
163+
/// Marks the given value as useful for the given dependent variable index.
164+
void setUseful(SILValue value, unsigned dependentVariableIndex);
165+
/// Marks the given value as useful and recursively propagates usefulness to:
166+
/// - Defining instruction operands, if the value has a defining instruction.
167+
/// - Incoming values, if the value is a basic block argument.
168+
void setUsefulAndPropagateToOperands(SILValue value,
169+
unsigned dependentVariableIndex);
170+
/// Propagates usefulness to the operands of the given instruction.
171+
void propagateUseful(SILInstruction *inst, unsigned dependentVariableIndex);
172+
/// Marks the given address or class-typed value as useful and recursively
173+
/// propagates usefulness inwards (to operands) through projections. Skips
174+
/// `@noDerivative` field projections.
175+
void propagateUsefulThroughAddress(SILValue value,
176+
unsigned dependentVariableIndex);
177+
/// If the given value is an `array.uninitialized_intrinsic` application,
178+
/// selectively propagate usefulness through its `RawPointer` result.
179+
void setUsefulThroughArrayInitialization(SILValue value,
180+
unsigned dependentVariableIndex);
181+
182+
public:
183+
explicit DifferentiableActivityInfo(
184+
DifferentiableActivityCollection &parent,
185+
GenericSignature derivativeGenericSignature);
186+
187+
/// Returns true if the given value is varied for the given independent
188+
/// variable index.
189+
bool isVaried(SILValue value, unsigned independentVariableIndex) const;
190+
191+
/// Returns true if the given value is varied for any of the given parameter
192+
/// (independent variable) indices.
193+
bool isVaried(SILValue value, IndexSubset *parameterIndices) const;
194+
195+
/// Returns true if the given value is useful for the given dependent variable
196+
/// index.
197+
bool isUseful(SILValue value, unsigned dependentVariableIndex) const;
198+
199+
/// Returns true if the given value is active for the given
200+
/// `SILAutoDiffIndices` (parameter indices and result index).
201+
bool isActive(SILValue value, const SILAutoDiffIndices &indices) const;
202+
203+
/// Returns the activity of the given value for the given `SILAutoDiffIndices`
204+
/// (parameter indices and result index).
205+
Activity getActivity(SILValue value, const SILAutoDiffIndices &indices) const;
206+
207+
/// Prints activity information for the `indices` of the given `value`.
208+
void dump(SILValue value, const SILAutoDiffIndices &indices,
209+
llvm::raw_ostream &s = llvm::dbgs()) const;
210+
211+
/// Prints activity information for the given `indices`.
212+
void dump(SILAutoDiffIndices indices,
213+
llvm::raw_ostream &s = llvm::dbgs()) const;
214+
};
215+
216+
/// Activity analysis results for a single function: one
/// `DifferentiableActivityInfo` per derivative generic signature, created
/// lazily by `getActivityInfo`.
class DifferentiableActivityCollection {
217+
public:
218+
/// Cache of activity info, keyed by derivative generic signature.
SmallDenseMap<GenericSignature, DifferentiableActivityInfo> activityInfoMap;
219+
/// The function these results belong to.
SILFunction &function;
220+
/// Dominance info passed at construction (see the constructor below).
DominanceInfo *domInfo;
221+
/// Post-dominance info passed at construction (see the constructor below).
PostDominanceInfo *postDomInfo;
222+
223+
/// Returns the activity info for the generic signature `assocGenSig`.
///
/// If an entry already exists in `activityInfoMap` it is returned;
/// otherwise a new `DifferentiableActivityInfo` is constructed from this
/// collection and `assocGenSig`, inserted, and returned.
///
/// NOTE(review): `kind` is not used in this body; results are keyed on
/// `assocGenSig` only.
/// NOTE(review): the returned reference points into `activityInfoMap`; a
/// later insertion may relocate entries of a DenseMap-style container -
/// confirm callers do not keep it across further calls.
DifferentiableActivityInfo &
224+
getActivityInfo(GenericSignature assocGenSig,
225+
AutoDiffDerivativeFunctionKind kind) {
226+
auto activityInfoLookup = activityInfoMap.find(assocGenSig);
227+
if (activityInfoLookup != activityInfoMap.end())
228+
return activityInfoLookup->getSecond();
229+
auto insertion = activityInfoMap.insert(
230+
{assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)});
231+
return insertion.first->getSecond();
232+
}
233+
234+
/// Creates the collection for function `f` with dominance info `di` and
/// post-dominance info `pdi`. Defined out of line.
explicit DifferentiableActivityCollection(SILFunction &f, DominanceInfo *di,
235+
PostDominanceInfo *pdi);
236+
};
237+
238+
} // end namespace swift
239+
240+
#endif // SWIFT_SILOPTIMIZER_ANALYSIS_DIFFERENTIABLEACTIVITYANALYSIS_H_

0 commit comments

Comments
 (0)