Skip to content

Commit 62f6686

Browse files
authored
Merge pull request #30579 from dan-zheng/autodiff-upstream-sil
[AutoDiff upstream] [SIL] Add differentiable function instructions.
2 parents de46690 + 7e10a50 commit 62f6686

27 files changed

+1204
-5
lines changed

docs/SIL.rst

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5774,6 +5774,67 @@ The rules on generic substitutions are identical to those of ``apply``.
57745774
Differentiable Programming
57755775
~~~~~~~~~~~~~~~~~~~~~~~~~~
57765776

5777+
differentiable_function
5778+
```````````````````````
5779+
::
5780+
5781+
sil-instruction ::= 'differentiable_function'
5782+
sil-differentiable-function-parameter-indices
5783+
sil-value ':' sil-type
5784+
sil-differentiable-function-derivative-functions-clause?
5785+
5786+
sil-differentiable-function-parameter-indices ::=
5787+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
5788+
sil-differentiable-derivative-functions-clause ::=
5789+
'with_derivative'
5790+
'{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'
5791+
5792+
differentiable_function [parameters 0] %0 : $(T) -> T \
5793+
with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
5794+
5795+
Creates a ``@differentiable`` function from an original function operand and
5796+
derivative function operands (optional). There are two derivative function
5797+
kinds: a Jacobian-vector products (JVP) function and a vector-Jacobian products
5798+
(VJP) function.
5799+
5800+
``[parameters ...]`` specifies parameter indices that the original function is
5801+
differentiable with respect to.
5802+
5803+
The ``with_derivative`` clause specifies the derivative function operands
5804+
associated with the original function.
5805+
5806+
The differentiation transformation canonicalizes all `differentiable_function`
5807+
instructions, generating derivative functions if necessary to fill in derivative
5808+
function operands.
5809+
5810+
In raw SIL, the ``with_derivative`` clause is optional. In canonical SIL, the
5811+
``with_derivative`` clause is mandatory.
5812+
5813+
5814+
differentiable_function_extract
5815+
```````````````````````````````
5816+
::
5817+
5818+
sil-instruction ::= 'differentiable_function_extract'
5819+
'[' sil-differentiable-function-extractee ']'
5820+
sil-value ':' sil-type
5821+
('as' sil-type)?
5822+
5823+
sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp'
5824+
5825+
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
5826+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
5827+
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
5828+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \
5829+
as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector)
5830+
5831+
Extracts the original function or a derivative function from the given
5832+
``@differentiable`` function. The extractee is one of the following:
5833+
``[original]``, ``[jvp]``, or ``[vjp]``.
5834+
5835+
In lowered SIL, an explicit extractee type may be provided. This is currently
5836+
used by the LoadableByAddress transformation, which rewrites function types.
5837+
57775838
differentiability_witness_function
57785839
``````````````````````````````````
57795840
::

include/swift/AST/AutoDiff.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,41 @@ struct AutoDiffDerivativeFunctionKind {
7575
}
7676
};
7777

78+
/// A component of a SIL `@differentiable` function-typed value.
79+
struct NormalDifferentiableFunctionTypeComponent {
80+
enum innerty : unsigned { Original = 0, JVP = 1, VJP = 2 } rawValue;
81+
82+
NormalDifferentiableFunctionTypeComponent() = default;
83+
NormalDifferentiableFunctionTypeComponent(innerty rawValue)
84+
: rawValue(rawValue) {}
85+
NormalDifferentiableFunctionTypeComponent(
86+
AutoDiffDerivativeFunctionKind kind);
87+
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
88+
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
89+
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
90+
operator innerty() const { return rawValue; }
91+
92+
/// Returns the derivative function kind, if the component is a derivative
93+
/// function.
94+
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
95+
};
96+
97+
/// A component of a SIL `@differentiable(linear)` function-typed value.
98+
struct LinearDifferentiableFunctionTypeComponent {
99+
enum innerty : unsigned {
100+
Original = 0,
101+
Transpose = 1,
102+
} rawValue;
103+
104+
LinearDifferentiableFunctionTypeComponent() = default;
105+
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
106+
: rawValue(rawValue) {}
107+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
108+
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
109+
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
110+
operator innerty() const { return rawValue; }
111+
};
112+
78113
/// A derivative function configuration, uniqued in `ASTContext`.
79114
/// Identifies a specific derivative function given an original function.
80115
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {

include/swift/AST/DiagnosticsParse.def

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,6 +1628,17 @@ ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
16281628
"expected the index of a parameter to differentiate with respect to", ())
16291629
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
16301630
"expected the index of a result to differentiate from", ())
1631+
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
1632+
"expected '{' to start a derivative function list", ())
1633+
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,
1634+
"expected ',' between operands in a derivative function list", ())
1635+
ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
1636+
"expected '}' to start a derivative function list", ())
1637+
ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken,
1638+
"expected an extractee kind attribute, which can be one of '[original]', "
1639+
"'[jvp]', and '[vjp]'", ())
1640+
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
1641+
"expected an operand of a function type", ())
16311642
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
16321643
"expected a differentiability witness kind, which can be one of '[jvp]', "
16331644
"'[vjp]', or '[transpose]'", ())

include/swift/SIL/SILBuilder.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,23 @@ class SILBuilder {
21642164
// Differentiable programming instructions
21652165
//===--------------------------------------------------------------------===//
21662166

2167+
DifferentiableFunctionInst *createDifferentiableFunction(
2168+
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
2169+
Optional<std::pair<SILValue, SILValue>> JVPAndVJPFunctions = None) {
2170+
return insert(DifferentiableFunctionInst::create(
2171+
getModule(), getSILDebugLocation(Loc), ParameterIndices,
2172+
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
2173+
}
2174+
2175+
/// Note: explicit extractee type may be specified only in lowered SIL.
2176+
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
2177+
SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee,
2178+
SILValue Function, Optional<SILType> ExtracteeType = None) {
2179+
return insert(new (getModule()) DifferentiableFunctionExtractInst(
2180+
getModule(), getSILDebugLocation(Loc), Extractee, Function,
2181+
ExtracteeType));
2182+
}
2183+
21672184
/// Note: explicit function type may be specified only in lowered SIL.
21682185
DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction(
21692186
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,

include/swift/SIL/SILCloner.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2827,6 +2827,33 @@ void SILCloner<ImplClass>::visitKeyPathInst(KeyPathInst *Inst) {
28272827
opValues, getOpType(Inst->getType())));
28282828
}
28292829

2830+
template <typename ImplClass>
2831+
void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
2832+
DifferentiableFunctionInst *Inst) {
2833+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2834+
Optional<std::pair<SILValue, SILValue>> derivativeFns = None;
2835+
if (Inst->hasDerivativeFunctions())
2836+
derivativeFns = std::make_pair(getOpValue(Inst->getJVPFunction()),
2837+
getOpValue(Inst->getVJPFunction()));
2838+
recordClonedInstruction(
2839+
Inst, getBuilder().createDifferentiableFunction(
2840+
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
2841+
getOpValue(Inst->getOriginalFunction()), derivativeFns));
2842+
}
2843+
2844+
template <typename ImplClass>
2845+
void SILCloner<ImplClass>::visitDifferentiableFunctionExtractInst(
2846+
DifferentiableFunctionExtractInst *Inst) {
2847+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2848+
Optional<SILType> explicitExtracteeType = None;
2849+
if (Inst->hasExplicitExtracteeType())
2850+
explicitExtracteeType = Inst->getType();
2851+
recordClonedInstruction(
2852+
Inst, getBuilder().createDifferentiableFunctionExtract(
2853+
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
2854+
getOpValue(Inst->getOperand()), explicitExtracteeType));
2855+
}
2856+
28302857
template <typename ImplClass>
28312858
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
28322859
DifferentiabilityWitnessFunctionInst *Inst) {

include/swift/SIL/SILInstruction.h

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7976,6 +7976,126 @@ class TryApplyInst final
79767976
const GenericSpecializationInformation *SpecializationInfo);
79777977
};
79787978

7979+
/// DifferentiableFunctionInst - creates a `@differentiable` function-typed
7980+
/// value from an original function operand and derivative function operands
7981+
/// (optional). The differentiation transform canonicalizes
7982+
/// `differentiable_function` instructions, filling in derivative function
7983+
/// operands if missing.
7984+
class DifferentiableFunctionInst final
7985+
: public InstructionBaseWithTrailingOperands<
7986+
SILInstructionKind::DifferentiableFunctionInst,
7987+
DifferentiableFunctionInst, OwnershipForwardingSingleValueInst> {
7988+
private:
7989+
friend SILBuilder;
7990+
/// Differentiability parameter indices.
7991+
IndexSubset *ParameterIndices;
7992+
/// Indicates whether derivative function operands (JVP/VJP) exist.
7993+
bool HasDerivativeFunctions;
7994+
7995+
DifferentiableFunctionInst(SILDebugLocation DebugLoc,
7996+
IndexSubset *ParameterIndices,
7997+
SILValue OriginalFunction,
7998+
ArrayRef<SILValue> DerivativeFunctions,
7999+
bool HasOwnership);
8000+
8001+
static SILType getDifferentiableFunctionType(SILValue OriginalFunction,
8002+
IndexSubset *ParameterIndices);
8003+
8004+
static ValueOwnershipKind
8005+
getMergedOwnershipKind(SILValue OriginalFunction,
8006+
ArrayRef<SILValue> DerivativeFunctions);
8007+
8008+
public:
8009+
static DifferentiableFunctionInst *
8010+
create(SILModule &Module, SILDebugLocation Loc, IndexSubset *ParameterIndices,
8011+
SILValue OriginalFunction,
8012+
Optional<std::pair<SILValue, SILValue>> VJPAndJVPFunctions,
8013+
bool HasOwnership);
8014+
8015+
/// Returns the original function operand.
8016+
SILValue getOriginalFunction() const { return getOperand(0); }
8017+
8018+
/// Returns differentiability parameter indices.
8019+
IndexSubset *getParameterIndices() const { return ParameterIndices; }
8020+
8021+
/// Returns true if derivative functions (JVP/VJP) exist.
8022+
bool hasDerivativeFunctions() const { return HasDerivativeFunctions; }
8023+
8024+
/// Returns the derivative function operands if they exist.
8025+
/// Otherwise, return `None`.
8026+
Optional<std::pair<SILValue, SILValue>>
8027+
getOptionalDerivativeFunctionPair() const {
8028+
if (!HasDerivativeFunctions)
8029+
return None;
8030+
return std::make_pair(getOperand(1), getOperand(2));
8031+
}
8032+
8033+
ArrayRef<Operand> getDerivativeFunctionArray() const {
8034+
return getAllOperands().drop_front();
8035+
}
8036+
8037+
/// Returns the JVP function operand.
8038+
SILValue getJVPFunction() const {
8039+
assert(HasDerivativeFunctions);
8040+
return getOperand(1);
8041+
}
8042+
8043+
/// Returns the VJP function operand.
8044+
SILValue getVJPFunction() const {
8045+
assert(HasDerivativeFunctions);
8046+
return getOperand(2);
8047+
}
8048+
8049+
/// Returns the derivative function operand (JVP or VJP) with the given kind.
8050+
SILValue getDerivativeFunction(AutoDiffDerivativeFunctionKind kind) const {
8051+
switch (kind) {
8052+
case AutoDiffDerivativeFunctionKind::JVP:
8053+
return getJVPFunction();
8054+
case AutoDiffDerivativeFunctionKind::VJP:
8055+
return getVJPFunction();
8056+
}
8057+
}
8058+
};
8059+
8060+
/// DifferentiableFunctionExtractInst - extracts either the original or
8061+
/// derivative function value from a `@differentiable` function.
8062+
class DifferentiableFunctionExtractInst
8063+
: public UnaryInstructionBase<
8064+
SILInstructionKind::DifferentiableFunctionExtractInst,
8065+
SingleValueInstruction> {
8066+
private:
8067+
/// The extractee.
8068+
NormalDifferentiableFunctionTypeComponent Extractee;
8069+
/// True if the instruction has an explicit extractee type.
8070+
bool HasExplicitExtracteeType;
8071+
8072+
static SILType
8073+
getExtracteeType(SILValue function,
8074+
NormalDifferentiableFunctionTypeComponent extractee,
8075+
SILModule &module);
8076+
8077+
public:
8078+
/// Note: explicit extractee type may be specified only in lowered SIL.
8079+
explicit DifferentiableFunctionExtractInst(
8080+
SILModule &module, SILDebugLocation debugLoc,
8081+
NormalDifferentiableFunctionTypeComponent extractee, SILValue function,
8082+
Optional<SILType> extracteeType = None);
8083+
8084+
NormalDifferentiableFunctionTypeComponent getExtractee() const {
8085+
return Extractee;
8086+
}
8087+
8088+
AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const {
8089+
auto kind = Extractee.getAsDerivativeFunctionKind();
8090+
assert(kind);
8091+
return *kind;
8092+
}
8093+
8094+
bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; }
8095+
};
8096+
8097+
/// DifferentiabilityWitnessFunctionInst - Looks up a differentiability witness
8098+
/// function for a given original function.
79798099
class DifferentiabilityWitnessFunctionInst
79808100
: public InstructionBase<
79818101
SILInstructionKind::DifferentiabilityWitnessFunctionInst,

include/swift/SIL/SILNodes.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,11 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
692692
SingleValueInstruction, None, DoesNotRelease)
693693

694694
// Differentiable programming
695+
SINGLE_VALUE_INST(DifferentiableFunctionInst, differentiable_function,
696+
SingleValueInstruction, None, DoesNotRelease)
697+
SINGLE_VALUE_INST(DifferentiableFunctionExtractInst,
698+
differentiable_function_extract,
699+
SingleValueInstruction, None, DoesNotRelease)
695700
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
696701
differentiability_witness_function,
697702
SingleValueInstruction, None, DoesNotRelease)

lib/AST/AutoDiff.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,50 @@ AutoDiffDerivativeFunctionKind::AutoDiffDerivativeFunctionKind(
2828
rawValue = *result;
2929
}
3030

31+
NormalDifferentiableFunctionTypeComponent::
32+
NormalDifferentiableFunctionTypeComponent(
33+
AutoDiffDerivativeFunctionKind kind) {
34+
switch (kind) {
35+
case AutoDiffDerivativeFunctionKind::JVP:
36+
rawValue = JVP;
37+
return;
38+
case AutoDiffDerivativeFunctionKind::VJP:
39+
rawValue = VJP;
40+
return;
41+
}
42+
}
43+
44+
NormalDifferentiableFunctionTypeComponent::
45+
NormalDifferentiableFunctionTypeComponent(StringRef string) {
46+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
47+
.Case("original", Original)
48+
.Case("jvp", JVP)
49+
.Case("vjp", VJP);
50+
assert(result && "Invalid string");
51+
rawValue = *result;
52+
}
53+
54+
Optional<AutoDiffDerivativeFunctionKind>
55+
NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const {
56+
switch (rawValue) {
57+
case Original:
58+
return None;
59+
case JVP:
60+
return {AutoDiffDerivativeFunctionKind::JVP};
61+
case VJP:
62+
return {AutoDiffDerivativeFunctionKind::VJP};
63+
}
64+
}
65+
66+
LinearDifferentiableFunctionTypeComponent::
67+
LinearDifferentiableFunctionTypeComponent(StringRef string) {
68+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
69+
.Case("original", Original)
70+
.Case("transpose", Transpose);
71+
assert(result && "Invalid string");
72+
rawValue = *result;
73+
}
74+
3175
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
3276
StringRef string) {
3377
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)

lib/IRGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_swift_host_library(swiftIRGen STATIC
1616
GenControl.cpp
1717
GenCoverage.cpp
1818
GenDecl.cpp
19+
GenDiffFunc.cpp
1920
GenDiffWitness.cpp
2021
GenEnum.cpp
2122
GenExistential.cpp

0 commit comments

Comments
 (0)