Skip to content

[AutoDiff] Introduce 'linear_function' and 'linear_function_extract' instructions. #27637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2bb558e
[AutoDiff] Introduce 'linear_function' instruction.
rxwei Oct 12, 2019
496fb9a
Fix `@differentiable(linear)` parsing ambiguity.
rxwei Oct 12, 2019
90016e7
Enable `@differentiable(linear)` function types.
rxwei Oct 12, 2019
a42763e
Merge branch 'tensorflow' of github.com:apple/swift into linear-funct…
rxwei Oct 12, 2019
782aada
Fix test.
rxwei Oct 12, 2019
3256d44
Handle '@differentiable(linear)' in SIL serialization and lowering.
rxwei Oct 12, 2019
f4c0411
Add 'linear_function_extract' instruction.
rxwei Oct 13, 2019
8971a80
Handle '@differentiable(linear)' type lowering.
rxwei Oct 13, 2019
8af911c
Gardening.
rxwei Oct 13, 2019
eb34058
Fix ASTGen for `@differentiable(linear)`.
rxwei Oct 14, 2019
4f1e371
Correct BLOCK_RECORD order in serialization.
rxwei Oct 14, 2019
fe42900
Update tests.
rxwei Oct 14, 2019
18f7f7c
Register SIL abbr codes for new instructions.
rxwei Oct 14, 2019
a1b6f02
ASTGen: Check whether a '@differentiable' attribute argument exists.
rxwei Oct 14, 2019
aa1cef3
Add '@differentiable(linear)' SIL parsing test.
rxwei Oct 14, 2019
43ed80c
Merge branch 'tensorflow' of github.com:apple/swift into linear-funct…
rxwei Oct 14, 2019
a86bd7f
Update module version.
rxwei Oct 14, 2019
91f6f66
Apply changes to tests in #27659 and add IRGen RUN line
rxwei Oct 14, 2019
ac8ab16
Merge branch 'tensorflow' of github.com:apple/swift into linear-funct…
rxwei Oct 14, 2019
cea8ff5
Update test.
rxwei Oct 14, 2019
76721ab
Merge branch 'tensorflow' of github.com:apple/swift into linear-funct…
rxwei Oct 14, 2019
5af2b01
Update to use LinearDifferentiableFunctionTypeComponent.
rxwei Oct 14, 2019
d1b0c55
Add to SIL language reference.
rxwei Oct 15, 2019
ca43c62
Add SIL verification.
rxwei Oct 15, 2019
71d06fd
Fix SIL docs.
rxwei Oct 15, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5607,18 +5607,13 @@ differentiable_function

sil-instruction ::= 'differentiable_function'
sil-differentiable-function-parameter-indices?
sil-differentiable-function-order?
sil-value ':' sil-type
sil-differentiable-function-derivative-functions-clause?

sil-differentiable-function-parameter-indices ::=
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
'[' 'wrt' [0-9]+ (' ' [0-9]+)* ']'
sil-differentiable-derivative-functions-clause ::=
'with' sil-differentiable-derivative-function-list
(',' sil-differentiable-derivative-function-list)*
sil-differentiable-function-derivative-function-list ::=
'{' sil-value ',' sil-value '}'
'with' '{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'

differentiable_function [wrt 0] %0 : $(T) -> T \
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
Expand All @@ -5640,20 +5635,50 @@ In raw SIL, it is optional to provide a derivative function ``with`` clause.
In canonical SIL, a ``with`` clause is mandatory.


linear_function
```````````````

::

sil-instruction ::= 'linear_function'
sil-linear-function-parameter-indices?
sil-value ':' sil-type
sil-linear-function-transpose-function-clause?

sil-linear-function-parameter-indices ::=
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
sil-linear-transpose-function-clause ::=
with_transpose sil-value ':' sil-type

linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T

Bundles a function with its transpose function into a
``@differentiable(linear)`` function.

``[parameters ...]`` specifies parameter indices that the original function is
linear with respect to. When not specified, it defaults to all parameters.

A ``with_transpose`` clause specifies the transpose function associated
with the original function. When a ``with_transpose`` clause is not specified,
the mandatory differentiation transform will add a ``with_transpose`` clause to
the instruction.

In raw SIL, it is optional to provide a transpose function ``with`` clause.
In canonical SIL, a ``with`` clause is mandatory.


differentiable_function_extract
```````````````````````````````

::

sil-instruction ::= 'differentiable_function_extract'
sil-differentiable-function-extractee
sil-differentiable-function-order?
sil-value ':' sil-type

sil-differentiable-function-extractee ::=
'[' sil-differentiable-function-extractee ']'
sil-differentiable-function-extractee-name ::= 'original' | 'jvp' | 'vjp'
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'

differentiable_function_extract [original] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
Expand All @@ -5664,6 +5689,27 @@ Extracts the original function or a derivative function from the given
``[original]``, ``[jvp]`` or ``[vjp]``.


linear_function_extract
```````````````````````

::

sil-instruction ::= 'linear_function_extract'
sil-linear-function-extractee
sil-value ':' sil-type

sil-linear-function-extractee ::=
'[' sil-linear-function-extractee ']'
sil-linear-function-extractee-name ::= 'original' | 'jvp' | 'vjp'

linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T

Extracts the original function or a transpose function from the given
``@differentiable(linear)`` function. It must be provided with an extractee:
``[original]`` or ``[transpose]``.


Assertion configuration
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
8 changes: 6 additions & 2 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ class TypeAttributes {
bool isValid() const { return AtLoc.isValid(); }

// SWIFT_ENABLE_TENSORFLOW
bool isLinear() const { return linear; }

bool isLinear() const {
assert(!linear || (linear && has(TAK_differentiable)) &&
"Linear shouldn't have been true if there's no `@differentiable`");
return linear;
}

void clearAttribute(TypeAttrKind A) {
AttrLocs[A] = SourceLoc();
}
Expand Down
24 changes: 17 additions & 7 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class SILFunctionType;
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
enum class SILLinkage : uint8_t;

enum class DifferentiabilityKind : uint8_t {
NonDifferentiable = 0b00,
Normal = 0b01,
Linear = 0b11
enum class DifferentiabilityKind: uint8_t {
NonDifferentiable = 0,
Normal = 1,
Linear = 2
};

// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`.
Expand All @@ -52,9 +52,19 @@ enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
VJP = 2
};

enum class LinearDifferentiableFunctionTypeComponent : uint8_t {
Original = 0,
Transpose = 1
struct LinearDifferentiableFunctionTypeComponent {
enum innerty : unsigned {
Original = 0,
Transpose = 1,
} rawValue;

LinearDifferentiableFunctionTypeComponent() = default;
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
: rawValue(rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) :
LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
operator innerty() const { return rawValue; }
};

class ParsedAutoDiffParameter {
Expand Down
8 changes: 6 additions & 2 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1614,8 +1614,12 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
"expected '}' to start a derivative function list", ())
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
"the number of operand lists does not match the order", ())
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
"expected a derivative function kind attribute, e.g. '[jvp]'", ())
ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken,
"expected an extractee kind attribute, which can be one of '[original]', "
"'[jvp]', and '[vjp]'", ())
ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken,
"expected an extractee kind attribute, which can be one of '[original]' "
"and '[transpose]'", ())
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
"expected an operand of a function type", ())

Expand Down
2 changes: 0 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2913,8 +2913,6 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
"overriding declaration is missing attribute '%0'", (StringRef))
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))
ERROR(linear_differentiable_type_disabled,none,
"'@differentiable(linear)' types are not yet supported", ())

// @differentiating
ERROR(differentiating_attr_expected_result_tuple,none,
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4216,7 +4216,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

// SWIFT_ENABLE_TENSORFLOW
CanSILFunctionType getWithDifferentiability(
IndexSubset *parameterIndices);
DifferentiabilityKind kind, IndexSubset *parameterIndices);

CanSILFunctionType getWithoutDifferentiability();

Expand Down
15 changes: 15 additions & 0 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,14 @@ class SILBuilder {
getModule(), getSILDebugLocation(Loc), ParameterIndices,
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
}

LinearFunctionInst *createLinearFunction(
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
Optional<SILValue> TransposeFunction) {
return insert(LinearFunctionInst::create(
getModule(), getSILDebugLocation(Loc), ParameterIndices,
OriginalFunction, TransposeFunction, hasOwnership()));
}

DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
SILLocation Loc, DifferentiableFunctionExtractee Extractee,
Expand All @@ -526,6 +534,13 @@ class SILBuilder {
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
}

LinearFunctionExtractInst *createLinearFunctionExtract(
SILLocation Loc, LinearDifferentiableFunctionTypeComponent Extractee,
SILValue TheFunction) {
return insert(new (getModule()) LinearFunctionExtractInst(
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
}

DifferentiableFunctionExtractInst *
createDifferentiableFunctionExtractOriginal(SILLocation Loc,
SILValue TheFunction) {
Expand Down
22 changes: 22 additions & 0 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,18 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
getOpValue(Inst->getOriginalFunction()), derivativeFns));
}

template<typename ImplClass>
void SILCloner<ImplClass>::visitLinearFunctionInst(LinearFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
auto transpose = Inst->getOptionalTransposeFunction();
if (transpose)
transpose = getOpValue(*transpose);
recordClonedInstruction(
Inst, getBuilder().createLinearFunction(
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
getOpValue(Inst->getOriginalFunction()), transpose));
}

template<typename ImplClass>
void SILCloner<ImplClass>::
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst) {
Expand All @@ -989,6 +1001,16 @@ visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *Inst)
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
getOpValue(Inst->getFunctionOperand())));
}

template<typename ImplClass>
void SILCloner<ImplClass>::
visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
recordClonedInstruction(
Inst, getBuilder().createLinearFunctionExtract(
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
getOpValue(Inst->getFunctionOperand())));
}
// SWIFT_ENABLE_TENSORFLOW END

template<typename ImplClass>
Expand Down
83 changes: 79 additions & 4 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -7845,7 +7845,7 @@ class TryApplyInst final
};

// SWIFT_ENABLE_TENSORFLOW
/// `differentiable_function` - given a function and differentiation indices and
/// `differentiable_function` - given a function, differentiation indices and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: Oxford comma please (used below for linear_function) 🙂

Suggested change
/// `differentiable_function` - given a function, differentiation indices and
/// `differentiable_function` - given a function, differentiation indices, and

/// its derivative functions, create an `@differentiable` function that
/// represents a bundle of these functions and configurations.
class DifferentiableFunctionInst final :
Expand All @@ -7865,14 +7865,14 @@ class DifferentiableFunctionInst final :
bool HasOwnership);

static SILType getDifferentiableFunctionType(
SILValue Original, IndexSubset *ParameterIndices);
SILValue OriginalFunction, IndexSubset *ParameterIndices);

static ValueOwnershipKind getMergedOwnershipKind(
SILValue Original, ArrayRef<SILValue> DerivativeFunctions);
SILValue OriginalFunction, ArrayRef<SILValue> DerivativeFunctions);

public:
static DifferentiableFunctionInst *create(
SILModule &Module, SILDebugLocation DebugLoc,
SILModule &Module, SILDebugLocation Loc,
IndexSubset *ParameterIndices, SILValue OriginalFunction,
Optional<std::pair<SILValue, SILValue>> VJPAndJVPFunctions,
bool HasOwnership);
Expand Down Expand Up @@ -7920,6 +7920,46 @@ class DifferentiableFunctionInst final :
}
};

/// `linear_function` - given a function, differentiation parameter indices,
/// result indices, and its derivative functions, create an `@differentiable`
/// function that represents a bundle of these functions and configurations.
class LinearFunctionInst final :
public InstructionBaseWithTrailingOperands<
SILInstructionKind::LinearFunctionInst,
LinearFunctionInst, OwnershipForwardingSingleValueInst> {
private:
friend SILBuilder;
/// Parameters to differentiate with respect to.
IndexSubset *ParameterIndices;
/// Indicates whether a transpose function exists.
bool HasTransposeFunction;

static SILType getLinearFunctionType(
SILValue OriginalFunction, IndexSubset *ParameterIndices);

public:
LinearFunctionInst(SILDebugLocation Loc, IndexSubset *ParameterIndices,
SILValue OriginalFunction,
Optional<SILValue> TransposeFunction, bool HasOwnership);

static LinearFunctionInst *create(SILModule &Module, SILDebugLocation Loc,
IndexSubset *ParameterIndices,
SILValue OriginalFunction,
Optional<SILValue> TransposeFunction,
bool HasOwnership);

IndexSubset *getParameterIndices() const { return ParameterIndices; }
bool hasTransposeFunction() const { return HasTransposeFunction; }
SILValue getOriginalFunction() const { return getOperand(0); }
Optional<SILValue> getOptionalTransposeFunction() const {
return HasTransposeFunction ? Optional<SILValue>(getOperand(1)) : None;
}
SILValue getTransposeFunction() const {
assert(HasTransposeFunction);
return getOperand(1);
}
};

/// `differentiable_function_extract` - given an `@differentiable` function
/// representing a bundle of the original function and derivative functions,
/// extract the specified function.
Expand Down Expand Up @@ -7974,6 +8014,41 @@ class DifferentiableFunctionExtractInst

typedef DifferentiableFunctionExtractInst::Extractee
DifferentiableFunctionExtractee;

/// `linear_function_extract` - given an `@differentiable(linear)` function
/// representing a bundle of the original function and the transpose function,
/// extract the specified function.
class LinearFunctionExtractInst
: public InstructionBase<
SILInstructionKind::LinearFunctionExtractInst,
SingleValueInstruction> {
private:
/// The extractee.
LinearDifferentiableFunctionTypeComponent extractee;
/// The list containing the `@differentiable(linear)` function operand.
FixedOperandList<1> operands;

static SILType
getExtracteeType(SILValue function,
LinearDifferentiableFunctionTypeComponent extractee,
SILModule &module);

public:
explicit LinearFunctionExtractInst(
SILModule &module, SILDebugLocation debugLoc,
LinearDifferentiableFunctionTypeComponent extractee,
SILValue theFunction);

LinearDifferentiableFunctionTypeComponent getExtractee() const {
return extractee;
}

SILValue getFunctionOperand() const { return operands[0].get(); }
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
};

typedef LinearDifferentiableFunctionTypeComponent LinearFunctionExtractee;
// SWIFT_ENABLE_TENSORFLOW END

// This is defined out of line to work around the fact that this depends on
Expand Down
5 changes: 5 additions & 0 deletions include/swift/SIL/SILNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,14 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
// Differentiable programming
SINGLE_VALUE_INST(DifferentiableFunctionInst, differentiable_function,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(LinearFunctionInst, linear_function,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(DifferentiableFunctionExtractInst,
differentiable_function_extract,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(LinearFunctionExtractInst,
linear_function_extract,
SingleValueInstruction, None, DoesNotRelease)

// Key paths
// TODO: The only "side effect" is potentially retaining the returned key path
Expand Down
Loading