Skip to content

Commit 740b63e

Browse files
authored
[AutoDiff] Introduce 'linear_function' and 'linear_function_extract' instructions. (#27637)
Introduce `linear_function` and `linear_function_extract` instructions, which are used for creating and destructing `@differentiable(linear)` functions. ### `linear_function` instruction Bundles a function with its transpose function into a `@differentiable(linear)` function. ``` sil-instruction ::= 'linear_function' sil-linear-function-parameter-indices? sil-value ':' sil-type sil-linear-function-transpose-function-clause? sil-linear-function-parameter-indices ::= '[' 'parameters' [0-9]+ (' ' [0-9]+)* ']' sil-linear-transpose-function-clause ::= with_transpose sil-value ':' sil-type ``` ### `linear_function_extract` instruction Extracts the original function or a transpose function from the given ``@differentiable(linear)`` function. It must be provided with an extractee: ``[original]`` or ``[transpose]``. ``` sil-instruction ::= 'linear_function_extract' sil-linear-function-extractee sil-value ':' sil-type sil-linear-function-extractee ::= '[' sil-linear-function-extractee ']' sil-linear-function-extractee-name ::= 'original' | 'transpose' linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T ``` Resolves [TF-907](https://bugs.swift.org/browse/TF-907).
1 parent 4a4fe83 commit 740b63e

35 files changed

+695
-107
lines changed

docs/SIL.rst

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5607,18 +5607,13 @@ differentiable_function
56075607

56085608
sil-instruction ::= 'differentiable_function'
56095609
sil-differentiable-function-parameter-indices?
5610-
sil-differentiable-function-order?
56115610
sil-value ':' sil-type
56125611
sil-differentiable-function-derivative-functions-clause?
56135612
56145613
sil-differentiable-function-parameter-indices ::=
5615-
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
5616-
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
5614+
'[' 'wrt' [0-9]+ (' ' [0-9]+)* ']'
56175615
sil-differentiable-derivative-functions-clause ::=
5618-
'with' sil-differentiable-derivative-function-list
5619-
(',' sil-differentiable-derivative-function-list)*
5620-
sil-differentiable-function-derivative-function-list ::=
5621-
'{' sil-value ',' sil-value '}'
5616+
'with' '{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'
56225617

56235618
differentiable_function [wrt 0] %0 : $(T) -> T \
56245619
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
@@ -5640,20 +5635,50 @@ In raw SIL, it is optional to provide a derivative function ``with`` clause.
56405635
In canonical SIL, a ``with`` clause is mandatory.
56415636

56425637

5638+
linear_function
5639+
```````````````
5640+
5641+
::
5642+
5643+
sil-instruction ::= 'linear_function'
5644+
sil-linear-function-parameter-indices?
5645+
sil-value ':' sil-type
5646+
sil-linear-function-transpose-function-clause?
5647+
5648+
sil-linear-function-parameter-indices ::=
5649+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
5650+
sil-linear-transpose-function-clause ::=
5651+
with_transpose sil-value ':' sil-type
5652+
5653+
linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T
5654+
5655+
Bundles a function with its transpose function into a
5656+
``@differentiable(linear)`` function.
5657+
5658+
``[parameters ...]`` specifies parameter indices that the original function is
5659+
linear with respect to. When not specified, it defaults to all parameters.
5660+
5661+
A ``with_transpose`` clause specifies the transpose function associated
5662+
with the original function. When a ``with_transpose`` clause is not specified,
5663+
the mandatory differentiation transform will add a ``with_transpose`` clause to
5664+
the instruction.
5665+
5666+
In raw SIL, it is optional to provide a transpose function ``with`` clause.
5667+
In canonical SIL, a ``with`` clause is mandatory.
5668+
5669+
56435670
differentiable_function_extract
56445671
```````````````````````````````
56455672

56465673
::
56475674

56485675
sil-instruction ::= 'differentiable_function_extract'
56495676
sil-differentiable-function-extractee
5650-
sil-differentiable-function-order?
56515677
sil-value ':' sil-type
56525678

56535679
sil-differentiable-function-extractee ::=
56545680
'[' sil-differentiable-function-extractee ']'
56555681
sil-differentiable-function-extractee-name ::= 'original' | 'jvp' | 'vjp'
5656-
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'
56575682

56585683
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
56595684
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
@@ -5664,6 +5689,27 @@ Extracts the original function or a derivative function from the given
56645689
``[original]``, ``[jvp]`` or ``[vjp]``.
56655690

56665691

5692+
linear_function_extract
5693+
```````````````````````
5694+
5695+
::
5696+
5697+
sil-instruction ::= 'linear_function_extract'
5698+
sil-linear-function-extractee
5699+
sil-value ':' sil-type
5700+
5701+
sil-linear-function-extractee ::=
5702+
'[' sil-linear-function-extractee ']'
5703+
sil-linear-function-extractee-name ::= 'original' | 'jvp' | 'vjp'
5704+
5705+
linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
5706+
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
5707+
5708+
Extracts the original function or a transpose function from the given
5709+
``@differentiable(linear)`` function. It must be provided with an extractee:
5710+
``[original]`` or ``[transpose]``.
5711+
5712+
56675713
Assertion configuration
56685714
~~~~~~~~~~~~~~~~~~~~~~~
56695715

include/swift/AST/Attr.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ class TypeAttributes {
8989
bool isValid() const { return AtLoc.isValid(); }
9090

9191
// SWIFT_ENABLE_TENSORFLOW
92-
bool isLinear() const { return linear; }
93-
92+
bool isLinear() const {
93+
assert(!linear || (linear && has(TAK_differentiable)) &&
94+
"Linear shouldn't have been true if there's no `@differentiable`");
95+
return linear;
96+
}
97+
9498
void clearAttribute(TypeAttrKind A) {
9599
AttrLocs[A] = SourceLoc();
96100
}

include/swift/AST/AutoDiff.h

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ class SILFunctionType;
3939
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
4040
enum class SILLinkage : uint8_t;
4141

42-
enum class DifferentiabilityKind : uint8_t {
43-
NonDifferentiable = 0b00,
44-
Normal = 0b01,
45-
Linear = 0b11
42+
enum class DifferentiabilityKind: uint8_t {
43+
NonDifferentiable = 0,
44+
Normal = 1,
45+
Linear = 2
4646
};
4747

4848
// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`.
@@ -52,9 +52,19 @@ enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
5252
VJP = 2
5353
};
5454

55-
enum class LinearDifferentiableFunctionTypeComponent : uint8_t {
56-
Original = 0,
57-
Transpose = 1
55+
struct LinearDifferentiableFunctionTypeComponent {
56+
enum innerty : unsigned {
57+
Original = 0,
58+
Transpose = 1,
59+
} rawValue;
60+
61+
LinearDifferentiableFunctionTypeComponent() = default;
62+
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
63+
: rawValue(rawValue) {}
64+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) :
65+
LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
66+
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
67+
operator innerty() const { return rawValue; }
5868
};
5969

6070
class ParsedAutoDiffParameter {

include/swift/AST/DiagnosticsParse.def

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,8 +1614,12 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
16141614
"expected '}' to start a derivative function list", ())
16151615
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
16161616
"the number of operand lists does not match the order", ())
1617-
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
1618-
"expected a derivative function kind attribute, e.g. '[jvp]'", ())
1617+
ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken,
1618+
"expected an extractee kind attribute, which can be one of '[original]', "
1619+
"'[jvp]', and '[vjp]'", ())
1620+
ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken,
1621+
"expected an extractee kind attribute, which can be one of '[original]' "
1622+
"and '[transpose]'", ())
16191623
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
16201624
"expected an operand of a function type", ())
16211625

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2913,8 +2913,6 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
29132913
"overriding declaration is missing attribute '%0'", (StringRef))
29142914
NOTE(protocol_witness_missing_differentiable_attr,none,
29152915
"candidate is missing attribute '%0'", (StringRef))
2916-
ERROR(linear_differentiable_type_disabled,none,
2917-
"'@differentiable(linear)' types are not yet supported", ())
29182916

29192917
// @differentiating
29202918
ERROR(differentiating_attr_expected_result_tuple,none,

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4216,7 +4216,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42164216

42174217
// SWIFT_ENABLE_TENSORFLOW
42184218
CanSILFunctionType getWithDifferentiability(
4219-
IndexSubset *parameterIndices);
4219+
DifferentiabilityKind kind, IndexSubset *parameterIndices);
42204220

42214221
CanSILFunctionType getWithoutDifferentiability();
42224222

include/swift/SIL/SILBuilder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,14 @@ class SILBuilder {
518518
getModule(), getSILDebugLocation(Loc), ParameterIndices,
519519
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
520520
}
521+
522+
LinearFunctionInst *createLinearFunction(
523+
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
524+
Optional<SILValue> TransposeFunction) {
525+
return insert(LinearFunctionInst::create(
526+
getModule(), getSILDebugLocation(Loc), ParameterIndices,
527+
OriginalFunction, TransposeFunction, hasOwnership()));
528+
}
521529

522530
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
523531
SILLocation Loc, DifferentiableFunctionExtractee Extractee,
@@ -526,6 +534,13 @@ class SILBuilder {
526534
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
527535
}
528536

537+
LinearFunctionExtractInst *createLinearFunctionExtract(
538+
SILLocation Loc, LinearDifferentiableFunctionTypeComponent Extractee,
539+
SILValue TheFunction) {
540+
return insert(new (getModule()) LinearFunctionExtractInst(
541+
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
542+
}
543+
529544
DifferentiableFunctionExtractInst *
530545
createDifferentiableFunctionExtractOriginal(SILLocation Loc,
531546
SILValue TheFunction) {

include/swift/SIL/SILCloner.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,18 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
980980
getOpValue(Inst->getOriginalFunction()), derivativeFns));
981981
}
982982

983+
template<typename ImplClass>
984+
void SILCloner<ImplClass>::visitLinearFunctionInst(LinearFunctionInst *Inst) {
985+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
986+
auto transpose = Inst->getOptionalTransposeFunction();
987+
if (transpose)
988+
transpose = getOpValue(*transpose);
989+
recordClonedInstruction(
990+
Inst, getBuilder().createLinearFunction(
991+
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
992+
getOpValue(Inst->getOriginalFunction()), transpose));
993+
}
994+
983995
template<typename ImplClass>
984996
void SILCloner<ImplClass>::
985997
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst) {
@@ -989,6 +1001,16 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst)
9891001
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
9901002
getOpValue(Inst->getFunctionOperand())));
9911003
}
1004+
1005+
template<typename ImplClass>
1006+
void SILCloner<ImplClass>::
1007+
visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) {
1008+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
1009+
recordClonedInstruction(
1010+
Inst, getBuilder().createLinearFunctionExtract(
1011+
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
1012+
getOpValue(Inst->getFunctionOperand())));
1013+
}
9921014
// SWIFT_ENABLE_TENSORFLOW END
9931015

9941016
template<typename ImplClass>

include/swift/SIL/SILInstruction.h

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7845,7 +7845,7 @@ class TryApplyInst final
78457845
};
78467846

78477847
// SWIFT_ENABLE_TENSORFLOW
7848-
/// `differentiable_function` - given a function and differentiation indices and
7848+
/// `differentiable_function` - given a function, differentiation indices and
78497849
/// its derivative functions, create an `@differentiable` function that
78507850
/// represents a bundle of these functions and configurations.
78517851
class DifferentiableFunctionInst final :
@@ -7865,14 +7865,14 @@ class DifferentiableFunctionInst final :
78657865
bool HasOwnership);
78667866

78677867
static SILType getDifferentiableFunctionType(
7868-
SILValue Original, IndexSubset *ParameterIndices);
7868+
SILValue OriginalFunction, IndexSubset *ParameterIndices);
78697869

78707870
static ValueOwnershipKind getMergedOwnershipKind(
7871-
SILValue Original, ArrayRef<SILValue> DerivativeFunctions);
7871+
SILValue OriginalFunction, ArrayRef<SILValue> DerivativeFunctions);
78727872

78737873
public:
78747874
static DifferentiableFunctionInst *create(
7875-
SILModule &Module, SILDebugLocation DebugLoc,
7875+
SILModule &Module, SILDebugLocation Loc,
78767876
IndexSubset *ParameterIndices, SILValue OriginalFunction,
78777877
Optional<std::pair<SILValue, SILValue>> VJPAndJVPFunctions,
78787878
bool HasOwnership);
@@ -7920,6 +7920,46 @@ class DifferentiableFunctionInst final :
79207920
}
79217921
};
79227922

7923+
/// `linear_function` - given a function, differentiation parameter indices,
7924+
/// result indices, and its derivative functions, create an `@differentiable`
7925+
/// function that represents a bundle of these functions and configurations.
7926+
class LinearFunctionInst final :
7927+
public InstructionBaseWithTrailingOperands<
7928+
SILInstructionKind::LinearFunctionInst,
7929+
LinearFunctionInst, OwnershipForwardingSingleValueInst> {
7930+
private:
7931+
friend SILBuilder;
7932+
/// Parameters to differentiate with respect to.
7933+
IndexSubset *ParameterIndices;
7934+
/// Indicates whether a transpose function exists.
7935+
bool HasTransposeFunction;
7936+
7937+
static SILType getLinearFunctionType(
7938+
SILValue OriginalFunction, IndexSubset *ParameterIndices);
7939+
7940+
public:
7941+
LinearFunctionInst(SILDebugLocation Loc, IndexSubset *ParameterIndices,
7942+
SILValue OriginalFunction,
7943+
Optional<SILValue> TransposeFunction, bool HasOwnership);
7944+
7945+
static LinearFunctionInst *create(SILModule &Module, SILDebugLocation Loc,
7946+
IndexSubset *ParameterIndices,
7947+
SILValue OriginalFunction,
7948+
Optional<SILValue> TransposeFunction,
7949+
bool HasOwnership);
7950+
7951+
IndexSubset *getParameterIndices() const { return ParameterIndices; }
7952+
bool hasTransposeFunction() const { return HasTransposeFunction; }
7953+
SILValue getOriginalFunction() const { return getOperand(0); }
7954+
Optional<SILValue> getOptionalTransposeFunction() const {
7955+
return HasTransposeFunction ? Optional<SILValue>(getOperand(1)) : None;
7956+
}
7957+
SILValue getTransposeFunction() const {
7958+
assert(HasTransposeFunction);
7959+
return getOperand(1);
7960+
}
7961+
};
7962+
79237963
/// `differentiable_function_extract` - given an `@differentiable` function
79247964
/// representing a bundle of the original function and derivative functions,
79257965
/// extract the specified function.
@@ -7974,6 +8014,41 @@ class DifferentiableFunctionExtractInst
79748014

79758015
typedef DifferentiableFunctionExtractInst::Extractee
79768016
DifferentiableFunctionExtractee;
8017+
8018+
/// `linear_function_extract` - given an `@differentiable(linear)` function
8019+
/// representing a bundle of the original function and the transpose function,
8020+
/// extract the specified function.
8021+
class LinearFunctionExtractInst
8022+
: public InstructionBase<
8023+
SILInstructionKind::LinearFunctionExtractInst,
8024+
SingleValueInstruction> {
8025+
private:
8026+
/// The extractee.
8027+
LinearDifferentiableFunctionTypeComponent extractee;
8028+
/// The list containing the `@differentiable(linear)` function operand.
8029+
FixedOperandList<1> operands;
8030+
8031+
static SILType
8032+
getExtracteeType(SILValue function,
8033+
LinearDifferentiableFunctionTypeComponent extractee,
8034+
SILModule &module);
8035+
8036+
public:
8037+
explicit LinearFunctionExtractInst(
8038+
SILModule &module, SILDebugLocation debugLoc,
8039+
LinearDifferentiableFunctionTypeComponent extractee,
8040+
SILValue theFunction);
8041+
8042+
LinearDifferentiableFunctionTypeComponent getExtractee() const {
8043+
return extractee;
8044+
}
8045+
8046+
SILValue getFunctionOperand() const { return operands[0].get(); }
8047+
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
8048+
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
8049+
};
8050+
8051+
typedef LinearDifferentiableFunctionTypeComponent LinearFunctionExtractee;
79778052
// SWIFT_ENABLE_TENSORFLOW END
79788053

79798054
// This is defined out of line to work around the fact that this depends on

include/swift/SIL/SILNodes.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,14 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
692692
// Differentiable programming
693693
SINGLE_VALUE_INST(DifferentiableFunctionInst, differentiable_function,
694694
SingleValueInstruction, None, DoesNotRelease)
695+
SINGLE_VALUE_INST(LinearFunctionInst, linear_function,
696+
SingleValueInstruction, None, DoesNotRelease)
695697
SINGLE_VALUE_INST(DifferentiableFunctionExtractInst,
696698
differentiable_function_extract,
697699
SingleValueInstruction, None, DoesNotRelease)
700+
SINGLE_VALUE_INST(LinearFunctionExtractInst,
701+
linear_function_extract,
702+
SingleValueInstruction, None, DoesNotRelease)
698703

699704
// Key paths
700705
// TODO: The only "side effect" is potentially retaining the returned key path

0 commit comments

Comments
 (0)