Skip to content

Commit 146c11e

Browse files
authored
[AutoDiff upstream] Add differentiable_function canonicalization. (#30818)
Canonicalizes `differentiable_function` instructions by filling in missing derivative function operands. Derivative function emission rules, based on the original function value: - `function_ref`: look up differentiability witness with the exact or a minimal superset derivative configuration. Emit a `differentiability_witness_function` for the derivative function. - `witness_method`: emit a `witness_method` with the minimal superset derivative configuration for the derivative function. - `class_method`: emit a `class_method` with the minimal superset derivative configuration for the derivative function. If an *actual* emitted derivative function has a superset derivative configuration versus the *desired* derivative configuration, create a "subset parameters thunk" to thunk the actual derivative to the desired type. For `differentiable_function` instructions formed from curry thunk applications: clone the curry thunk (with type `(Self) -> (T, ...) -> U`) and create a new version with type `(Self) -> @differentiable (T, ...) -> U`. Progress towards TF-1211.
1 parent 15f512b commit 146c11e

File tree

13 files changed

+2091
-13
lines changed

13 files changed

+2091
-13
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,42 @@ NOTE(autodiff_control_flow_not_supported,none,
474474
"cannot differentiate unsupported control flow", ())
475475
NOTE(autodiff_missing_return,none,
476476
"missing return for differentiation", ())
477+
NOTE(autodiff_external_nondifferentiable_function,none,
478+
"cannot differentiate functions that have not been marked "
479+
"'@differentiable' and that are defined in other files", ())
480+
NOTE(autodiff_opaque_function_not_differentiable,none,
481+
"opaque non-'@differentiable' function is not differentiable", ())
482+
NOTE(autodiff_private_derivative_from_fragile,none,
483+
"differentiated functions in "
484+
"%select{'@inlinable' functions|default arguments}0 must be marked "
485+
"'@differentiable' or have a public '@derivative'"
486+
"%select{|; this is not possible with a closure, make a top-level "
487+
"function instead}1", (unsigned, bool))
488+
NOTE(autodiff_function_noderivative_parameter_not_differentiable,none,
489+
"cannot differentiate with respect to a '@noDerivative' parameter", ())
490+
NOTE(autodiff_function_assoc_func_unmet_requirements,none,
491+
"function call is not differentiable because generic requirements are not "
492+
"met: '%0'", (/*requirements*/ StringRef))
493+
NOTE(autodiff_nondifferentiable_argument,none,
494+
"cannot differentiate through a non-differentiable argument; do you want "
495+
"to use 'withoutDerivative(at:)'?", ())
496+
NOTE(autodiff_nondifferentiable_result,none,
497+
"cannot differentiate through a non-differentiable result; do you want to "
498+
"use 'withoutDerivative(at:)'?", ())
499+
NOTE(autodiff_protocol_member_not_differentiable,none,
500+
"member is not differentiable because the corresponding protocol "
501+
"requirement is not '@differentiable'", ())
502+
NOTE(autodiff_class_member_not_differentiable,none,
503+
"member is not differentiable because the corresponding class member "
504+
"is not '@differentiable'", ())
505+
NOTE(autodiff_member_subset_indices_not_differentiable,none,
506+
"member is differentiable only with respect to a smaller subset of "
507+
"arguments", ())
508+
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`
509+
// functions.
510+
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
511+
"cannot convert a direct method reference to a '@differentiable' "
512+
"function; use an explicit closure instead", ())
477513

478514
ERROR(non_physical_addressof,none,
479515
"addressof only works with purely physical lvalues; "

include/swift/SIL/SILInstruction.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,6 +2056,25 @@ class ApplyInstBase<Impl, Base, false> : public Base {
20562056
/// does it have the given semantics?
20572057
bool doesApplyCalleeHaveSemantics(SILValue callee, StringRef semantics);
20582058

2059+
/// Predicate used to filter InoutArgumentRange.
2060+
struct OperandToInoutArgument {
2061+
ArrayRef<SILParameterInfo> paramInfos;
2062+
OperandValueArrayRef arguments;
2063+
OperandToInoutArgument(ArrayRef<SILParameterInfo> paramInfos,
2064+
OperandValueArrayRef arguments)
2065+
: paramInfos(paramInfos), arguments(arguments) {
2066+
assert(paramInfos.size() == arguments.size());
2067+
}
2068+
Optional<SILValue> operator()(size_t i) const {
2069+
if (paramInfos[i].isIndirectMutating())
2070+
return arguments[i];
2071+
return None;
2072+
}
2073+
};
2074+
2075+
using InoutArgumentRange =
2076+
OptionalTransformRange<IntRange<size_t>, OperandToInoutArgument>;
2077+
20592078
/// The partial specialization of ApplyInstBase for full applications.
20602079
/// Adds some methods relating to 'self' and to result types that don't
20612080
/// make sense for partial applications.
@@ -2068,6 +2087,9 @@ class ApplyInstBase<Impl, Base, true>
20682087
ApplyInstBase(As &&...args)
20692088
: ApplyInstBase<Impl, Base, false>(std::forward<As>(args)...) {}
20702089

2090+
private:
2091+
const Impl &asImpl() const { return static_cast<const Impl &>(*this); }
2092+
20712093
public:
20722094
using super::getCallee;
20732095
using super::getSubstCalleeType;
@@ -2152,6 +2174,16 @@ class ApplyInstBase<Impl, Base, true>
21522174
return getArguments().slice(getNumIndirectResults());
21532175
}
21542176

2177+
/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
2178+
/// instruction.
2179+
InoutArgumentRange getInoutArguments() const {
2180+
auto &impl = asImpl();
2181+
return InoutArgumentRange(
2182+
indices(getArgumentsWithoutIndirectResults()),
2183+
OperandToInoutArgument(impl.getSubstCalleeConv().getParameters(),
2184+
impl.getArgumentsWithoutIndirectResults()));
2185+
}
2186+
21552187
bool hasSemantics(StringRef semanticsString) const {
21562188
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
21572189
}

include/swift/SIL/SILType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,9 @@ class SILType {
552552
bool isLoweringOf(TypeExpansionContext context, SILModule &M,
553553
CanType formalType);
554554

555+
/// Returns true if this SILType is a differentiable type.
556+
bool isDifferentiable(SILModule &M) const;
557+
555558
/// Returns the hash code for the SILType.
556559
llvm::hash_code getHashCode() const {
557560
return llvm::hash_combine(*this);

include/swift/SILOptimizer/Utils/Differentiation/Common.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "swift/SIL/SILDifferentiabilityWitness.h"
2121
#include "swift/SIL/SILFunction.h"
2222
#include "swift/SIL/SILModule.h"
23+
#include "swift/SIL/TypeSubstCloner.h"
2324

2425
namespace swift {
2526

@@ -33,6 +34,12 @@ namespace autodiff {
3334
/// This is being used to print short debug messages within the AD pass.
3435
raw_ostream &getADDebugStream();
3536

37+
/// Given a function call site, gathers all of its actual results (both direct
38+
/// and indirect) in an order defined by its result type.
39+
void collectAllActualResultsInTypeOrder(
40+
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
41+
SmallVectorImpl<SILValue> &results);
42+
3643
/// Returns the underlying instruction for the given SILValue, if it exists,
3744
/// peering through function conversion instructions.
3845
template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
@@ -51,6 +58,70 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
5158
return nullptr;
5259
}
5360

61+
/// Given a range of elements, joins these into a single value. If there's
62+
/// exactly one element, returns that element. Otherwise, creates a tuple using
63+
/// a `tuple` instruction.
64+
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
65+
SILLocation loc);
66+
67+
/// Given a value, extracts all elements to `results` from this value if it has
68+
/// a tuple type. Otherwise, add this value directly to `results`.
69+
void extractAllElements(SILValue value, SILBuilder &builder,
70+
SmallVectorImpl<SILValue> &results);
71+
72+
/// Emit a zero value into the given buffer access by calling
73+
/// `AdditiveArithmetic.zero`. The given type must conform to
74+
/// `AdditiveArithmetic`.
75+
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
76+
SILValue bufferAccess, SILLocation loc);
77+
78+
//===----------------------------------------------------------------------===//
79+
// Utilities for looking up derivatives of functions
80+
//===----------------------------------------------------------------------===//
81+
82+
/// Returns a differentiability witness (definition or declaration) exactly
83+
/// matching the specified indices. If none are found in the given `module`,
84+
/// returns `nullptr`.
85+
///
86+
/// \param parameterIndices must be lowered to SIL.
87+
/// \param resultIndices must be lowered to SIL.
88+
SILDifferentiabilityWitness *
89+
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
90+
IndexSubset *parameterIndices,
91+
IndexSubset *resultIndices);
92+
93+
/// Finds the derivative configuration (from `@differentiable` and
94+
/// `@derivative` attributes) for `original` whose parameter indices are a
95+
/// minimal superset of the specified AST parameter indices. Returns `None` if
96+
/// no such configuration is found.
97+
///
98+
/// \param parameterIndices must be lowered to SIL.
99+
/// \param minimalASTParameterIndices is an output parameter that is set to the
100+
/// AST indices of the minimal configuration, or to `nullptr` if no such
101+
/// configuration exists.
102+
Optional<AutoDiffConfig>
103+
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
104+
IndexSubset *parameterIndices,
105+
IndexSubset *&minimalASTParameterIndices);
106+
107+
/// Returns a differentiability witness for `original` whose parameter indices
108+
/// are a minimal superset of the specified parameter indices and whose result
109+
/// indices match the given result indices, out of all
110+
/// differentiability witnesses that come from AST "@differentiable" or
111+
/// "@differentiating" attributes.
112+
///
113+
/// This function never creates new differentiability witness definitions.
114+
/// However, this function may create new differentiability witness declarations
115+
/// referring to definitions in other modules when these witnesses have not yet
116+
/// been declared in the current module.
117+
///
118+
/// \param module is the SILModule in which to get or create the witnesses.
119+
/// \param parameterIndices must be lowered to SIL.
120+
/// \param resultIndices must be lowered to SIL.
121+
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
122+
SILModule &module, SILFunction *original, IndexSubset *parameterIndices,
123+
IndexSubset *resultIndices);
124+
54125
} // end namespace autodiff
55126

56127
/// Creates arguments in the entry block based on the function type.
@@ -85,6 +156,34 @@ inline void createEntryArguments(SILFunction *f) {
85156
}
86157
}
87158

159+
/// Cloner that remaps types using the target function's generic environment.
160+
class BasicTypeSubstCloner final
161+
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {
162+
163+
static SubstitutionMap getSubstitutionMap(SILFunction *target) {
164+
if (auto *targetGenEnv = target->getGenericEnvironment())
165+
return targetGenEnv->getForwardingSubstitutionMap();
166+
return SubstitutionMap();
167+
}
168+
169+
public:
170+
explicit BasicTypeSubstCloner(SILFunction *original, SILFunction *target)
171+
: TypeSubstCloner(*target, *original, getSubstitutionMap(target)) {}
172+
173+
void postProcess(SILInstruction *orig, SILInstruction *cloned) {
174+
SILClonerWithScopes::postProcess(orig, cloned);
175+
}
176+
177+
void cloneFunction() {
178+
auto &newFunction = Builder.getFunction();
179+
auto *entry = newFunction.createBasicBlock();
180+
createEntryArguments(&newFunction);
181+
SmallVector<SILValue, 8> entryArguments(newFunction.getArguments().begin(),
182+
newFunction.getArguments().end());
183+
cloneFunctionBody(&Original, entry, entryArguments);
184+
}
185+
};
186+
88187
} // end namespace swift
89188

90189
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_COMMON_H
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===--- Thunk.h - Automatic differentiation thunks -----------*- C++ -*---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// Automatic differentiation thunk generation utilities.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H
18+
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H
19+
20+
#include "swift/AST/AutoDiff.h"
21+
#include "swift/Basic/LLVM.h"
22+
#include "swift/SIL/SILBuilder.h"
23+
24+
namespace swift {
25+
26+
class SILOptFunctionBuilder;
27+
class SILModule;
28+
class SILLocation;
29+
class SILValue;
30+
class OpenedArchetypeType;
31+
class GenericEnvironment;
32+
class SubstitutionMap;
33+
class ArchetypeType;
34+
35+
//===----------------------------------------------------------------------===//
36+
// Helpers
37+
//===----------------------------------------------------------------------===//
38+
39+
namespace autodiff {
40+
41+
//===----------------------------------------------------------------------===//
42+
// Thunk helpers
43+
//===----------------------------------------------------------------------===//
44+
// These helpers are copied/adapted from SILGen. They should be refactored and
45+
// moved to a shared location.
46+
//===----------------------------------------------------------------------===//
47+
48+
/// The thunk kinds used in the differentiation transform.
49+
enum class DifferentiationThunkKind {
50+
/// A reabstraction thunk.
51+
///
52+
/// Reabstraction thunks transform a function-typed value to another one with
53+
/// different parameter/result abstraction patterns. This is identical to the
54+
/// thunks generated by SILGen.
55+
Reabstraction,
56+
57+
/// An index subset thunk.
58+
///
59+
/// An index subset thunk is used transform JVP/VJPs into a version that is
60+
/// "wrt" fewer differentiation parameters.
61+
/// - Differentials of thunked JVPs use zero for non-requested differentiation
62+
/// parameters.
63+
/// - Pullbacks of thunked VJPs discard results for non-requested
64+
/// differentiation parameters.
65+
IndexSubset
66+
};
67+
68+
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
69+
OpenedArchetypeType *openedExistential,
70+
GenericEnvironment *&genericEnv,
71+
SubstitutionMap &contextSubs,
72+
SubstitutionMap &interfaceSubs,
73+
ArchetypeType *&newArchetype);
74+
75+
/// Build the type of a function transformation thunk.
76+
CanSILFunctionType buildThunkType(SILFunction *fn,
77+
CanSILFunctionType &sourceType,
78+
CanSILFunctionType &expectedType,
79+
GenericEnvironment *&genericEnv,
80+
SubstitutionMap &interfaceSubs,
81+
bool withoutActuallyEscaping,
82+
DifferentiationThunkKind thunkKind);
83+
84+
/// Get or create a derivative function parameter index subset thunk from
85+
/// `actualIndices` to `desiredIndices` for the given associated function
86+
/// value and original function operand. Returns a pair of the parameter
87+
/// index subset thunk and its interface substitution map (used to partially
88+
/// apply the thunk).
89+
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
90+
/// map returned by the derivative function.
91+
std::pair<SILFunction *, SubstitutionMap>
92+
getOrCreateSubsetParametersThunkForDerivativeFunction(
93+
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
94+
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
95+
SILAutoDiffIndices actualIndices);
96+
97+
/// Get or create a derivative function parameter index subset thunk from
98+
/// `actualIndices` to `desiredIndices` for the given associated function
99+
/// value and original function operand. Returns a pair of the parameter
100+
/// index subset thunk and its interface substitution map (used to partially
101+
/// apply the thunk).
102+
std::pair<SILFunction *, SubstitutionMap>
103+
getOrCreateSubsetParametersThunkForLinearMap(
104+
SILOptFunctionBuilder &fb, SILFunction *assocFn,
105+
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
106+
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
107+
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
108+
109+
} // end namespace autodiff
110+
111+
} // end namespace swift
112+
113+
#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_THUNK_H

lib/SIL/IR/SILType.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,12 @@ bool SILType::isLoweringOf(TypeExpansionContext context, SILModule &Mod,
619619
return loweredType.getASTType() == formalType;
620620
}
621621

622+
bool SILType::isDifferentiable(SILModule &M) const {
623+
return getASTType()
624+
->getAutoDiffTangentSpace(LookUpConformanceInModule(M.getSwiftModule()))
625+
.hasValue();
626+
}
627+
622628
Type
623629
TypeBase::replaceSubstitutedSILFunctionTypesWithUnsubstituted(SILModule &M) const {
624630
return Type(const_cast<TypeBase*>(this)).transform([&](Type t) -> Type {

0 commit comments

Comments
 (0)