Skip to content

Merge branch 'tensorflow' into tensorflow-merge #27745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c76bde9
Defines derivatives for remaining tgmath math functions. (#27559)
vguerra Oct 9, 2019
36676ae
[AutoDiff] Remap `apply` callee type in derivative context. (#27590)
dan-zheng Oct 9, 2019
fb6045c
[AutoDiff] Underscore the `Differentiable` protocol. (#27577)
dan-zheng Oct 9, 2019
eeeeee2
[AutoDiff] Remove differentiation order from AD-related instructions.…
rxwei Oct 10, 2019
0aee08a
[AutoDiff] Rename "assocFn" to "derivativeFn" everywhere except Diffe…
bgogul Oct 10, 2019
431cc43
[AutoDiff] Rename "associated function" to "derivative function". (#2…
rxwei Oct 10, 2019
ced939b
[AutoDiff] Fix subset parameters thunk `partial_apply` substitutions.…
dan-zheng Oct 10, 2019
9886b63
[AutoDiff] Rename 'AutoDiffIndexSubset' to 'IndexSubset'. (#27615)
rxwei Oct 11, 2019
68b96a2
[AutoDiff] Type-check `@differentiable` attributes during validation.…
dan-zheng Oct 11, 2019
483a378
[AutoDiff] NFC: Move some 'IndexSubset' method impls to their own fil…
rxwei Oct 11, 2019
7e03973
[AutoDiff] Fix 'autodiff_function_extract' operand ownership kind. (#…
rxwei Oct 12, 2019
65ebee0
[AutoDiff] Fix `@differentiable(linear)` type parsing ambiguity. (#27…
rxwei Oct 12, 2019
9c76a29
[AutoDiff] Add `@differentiable` attribute SILGen assertion. (#27650)
dan-zheng Oct 13, 2019
a5dc918
[AutoDiff] Add SIL differentiability witnesses. (#27487)
dan-zheng Oct 13, 2019
98f3545
[NFC] [AutoDiff] Gardening. (#27651)
dan-zheng Oct 13, 2019
db95e53
[AutoDiff] [ASTGen] Check for 'linear' when generating 'AttributedTyp…
rxwei Oct 14, 2019
5e52226
[AutoDiff] [Serialization] Fix '@differentiable(linear)' SIL function…
rxwei Oct 14, 2019
4a4fe83
[AutoDiff] [IRGen] Lower `@differentiable(linear)` function types. (#…
rxwei Oct 14, 2019
740b63e
[AutoDiff] Introduce 'linear_function' and 'linear_function_extract' …
rxwei Oct 15, 2019
bb67311
[AutoDiff] Diagnose unsupported forward-mode control flow. (#27684)
dan-zheng Oct 15, 2019
504b794
[AutoDiff] Support '@differentiable(linear)' function conversion. (#2…
rxwei Oct 15, 2019
1e4552c
[AutoDiff] NFC: Change `DifferentiableFunctionExtractee` to a top-lev…
rxwei Oct 15, 2019
9a55bc3
[AutoDiff] [Docs] NFC: Fix 'linear_function' syntax documentation. (#…
rxwei Oct 15, 2019
b987d64
[AutoDiff] Fix build failure introduced in #27688. (#27691)
rxwei Oct 15, 2019
bafacd8
[AutoDiff] [SIL] Tweak 'differentiable_function' syntax. (#27689)
rxwei Oct 15, 2019
0d17ddf
[AutoDiff] Destroy all pullback indirect results after adjoint accumu…
rxwei Oct 16, 2019
76729c4
[AutoDiff] SILGen differentiability witnesses. (#27652)
dan-zheng Oct 16, 2019
d93efc1
[AutoDiff] Add differentiability witness SILGen test. (#27717)
dan-zheng Oct 16, 2019
117cf39
Merge branch 'tensorflow' into tensorflow-merge
asuhan Oct 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 76 additions & 33 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5607,40 +5607,64 @@ differentiable_function

sil-instruction ::= 'differentiable_function'
sil-differentiable-function-parameter-indices?
sil-differentiable-function-order?
sil-value ':' sil-type
sil-differentiable-function-associated-functions-clause?
sil-differentiable-function-derivative-functions-clause?
sil-differentiable-function-parameter-indices ::=
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
sil-differentiable-associated-functions-clause ::=
'with' sil-differentiable-associated-function-list
(',' sil-differentiable-associated-function-list)*
sil-differentiable-function-associated-function-list ::=
'{' sil-value ',' sil-value '}'

differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}

Bundles a function with its associated differentiation functions up to a
specified differentiation order into an ``@differentiable`` function. There are
two associated functions per differentiation order: a Jacobian-vector products
(JVP) function and a vector-Jacobian products (VJP) function.

``[wrt ...]`` specifies parameter indices that the original function is
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
sil-differentiable-derivative-functions-clause ::=
'with_derivative'
'{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'

differentiable_function [parameters 0] %0 : $(T) -> T \
with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}

Bundles a function with its derivative functions into a ``@differentiable``
function. There are two derivative functions: a Jacobian-vector products (JVP)
function and a vector-Jacobian products (VJP) function.

``[parameters ...]`` specifies parameter indices that the original function is
differentiable with respect to. When not specified, it defaults to all
parameters.

``[order ...]`` specifies the maximum differentiation order for the resulting
function. The number of lists of associated functions is equal to the order.
A ``with_derivative`` clause specifies the differentiation functions associated
with the original function. When a ``with_derivative`` clause is not specified,
the first operand will be differentiated to produce derivative functions, and a
``with_derivative`` clause will be added to the instruction.

In raw SIL, it is optional to provide a derivative function ``with_derivative``
clause. In canonical SIL, a ``with_derivative`` clause is mandatory.


linear_function
```````````````

::

sil-instruction ::= 'linear_function'
sil-linear-function-parameter-indices?
sil-value ':' sil-type
sil-linear-function-transpose-function-clause?

sil-linear-function-parameter-indices ::=
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
sil-linear-transpose-function-clause ::=
with_transpose sil-value ':' sil-type

linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T

A ``with`` clause specifies the differentiation functions associated
with the original function. When a ``with`` clause is not specified, the first
operand will be differentiated to produce associated functions, and a ``with``
clause will be added to the instruction.
Bundles a function with its transpose function into a
``@differentiable(linear)`` function.

In raw SIL, it is optional to provide an associated function ``with`` clause.
``[parameters ...]`` specifies parameter indices that the original function is
linear with respect to. When not specified, it defaults to all parameters.

A ``with_transpose`` clause specifies the transpose function associated
with the original function. When a ``with_transpose`` clause is not specified,
the mandatory differentiation transform will add a ``with_transpose`` clause to
the instruction.

In raw SIL, it is optional to provide a transpose function ``with`` clause.
In canonical SIL, a ``with`` clause is mandatory.


Expand All @@ -5651,21 +5675,40 @@ differentiable_function_extract

sil-instruction ::= 'differentiable_function_extract'
sil-differentiable-function-extractee
sil-differentiable-function-order?
sil-value ':' sil-type

sil-differentiable-function-extractee ::=
'[' sil-differentiable-function-extractee ']'
sil-differentiable-function-extractee-name ::= 'original' | 'jvp' | 'vjp'
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'

differentiable_function_extract [original] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T

Extracts the original function or a derivative function from the given
``@differentiable`` function. It must be provided with an extractee:
``[original]``, ``[jvp]`` or ``[vjp]``.


linear_function_extract
```````````````````````

::

sil-instruction ::= 'linear_function_extract'
sil-linear-function-extractee
sil-value ':' sil-type

sil-linear-function-extractee ::=
'[' sil-linear-function-extractee ']'
sil-linear-function-extractee-name ::= 'original' | 'transpose'

linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T

Extracts the original function or an associated function from the given
``@differentiable`` function at a specific differentiation order. It must be
provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``.
Extracts the original function or a transpose function from the given
``@differentiable(linear)`` function. It must be provided with an extractee:
``[original]`` or ``[transpose]``.


Assertion configuration
Expand Down
6 changes: 2 additions & 4 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ namespace swift {
class TypeAliasDecl;
class VarDecl;
class UnifiedStatsReporter;
// SWIFT_ENABLE_TENSORFLOW
class AutoDiffIndexSubset;
class IndexSubset;
class VectorSpace;
class DifferentiableAttr;

Expand Down Expand Up @@ -280,8 +279,7 @@ class ASTContext final {
/// Cache of `@differentiable` attributes keyed by parameter indices. This
/// helps us diagnose multiple `@differentiable`s that are with respect to the
/// same set of parameters.
llvm::DenseMap<std::pair<Decl *, AutoDiffIndexSubset *>,
DifferentiableAttr *>
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
DifferentiableAttrs;

private:
Expand Down
30 changes: 19 additions & 11 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,23 +155,31 @@ class ASTMangler : public Mangler {
ModuleDecl *Module);

// SWIFT_ENABLE_TENSORFLOW
// Mangle the autodiff associated function (JVP/VJP) with the given:
// - Mangled original function name.
// - Associated function kind.
// - Parameter/result indices.
std::string mangleAutoDiffAssociatedFunctionHelper(
StringRef name, AutoDiffAssociatedFunctionKind kind,
/// Mangle the derivative function (JVP/VJP) with the given:
/// - Mangled original function name.
/// - Derivative function kind.
/// - Parameter/result indices.
std::string mangleAutoDiffDerivativeFunctionHelper(
StringRef name, AutoDiffDerivativeFunctionKind kind,
const SILAutoDiffIndices &indices);

// SWIFT_ENABLE_TENSORFLOW
// Mangle the autodiff linear map (differential/pullback) with the given:
// - Mangled original function name.
// - Linear map kind.
// - Parameter/result indices.
/// Mangle the autodiff linear map (differential/pullback) with the given:
/// - Mangled original function name.
/// - Linear map kind.
/// - Parameter/result indices.
std::string mangleAutoDiffLinearMapHelper(
StringRef name, AutoDiffLinearMapKind kind,
const SILAutoDiffIndices &indices);

/// Mangle a SIL differentiability witness key.
/// - Mangled original function name.
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
std::string mangleSILDifferentiabilityWitnessKey(
SILDifferentiabilityWitnessKey key);
// SWIFT_ENABLE_TENSORFLOW END

std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
GenericSignature signature,
CanType baseType,
Expand Down
58 changes: 31 additions & 27 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ class TypeAttributes {
bool isValid() const { return AtLoc.isValid(); }

// SWIFT_ENABLE_TENSORFLOW
bool isLinear() const { return linear; }

bool isLinear() const {
assert(!linear || (linear && has(TAK_differentiable)) &&
"Linear shouldn't have been true if there's no `@differentiable`");
return linear;
}

void clearAttribute(TypeAttrKind A) {
AttrLocs[A] = SourceLoc();
}
Expand Down Expand Up @@ -1537,8 +1541,8 @@ class DifferentiableAttr final
ParsedAutoDiffParameter> {
friend TrailingObjects;

/// Whether this function is linear (optional).
bool linear;
/// Whether this function is linear.
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The JVP function.
Expand All @@ -1552,10 +1556,10 @@ class DifferentiableAttr final
/// specified.
FuncDecl *VJPFunction = nullptr;
/// The differentiation parameters' indices, resolved by the type checker.
AutoDiffIndexSubset *ParameterIndices = nullptr;
IndexSubset *ParameterIndices = nullptr;
/// The trailing where clause (optional).
TrailingWhereClause *WhereClause = nullptr;
/// The generic signature for autodiff associated functions. Resolved by the
/// The generic signature for autodiff derivative functions. Resolved by the
/// type checker based on the original function's generic signature and the
/// attribute's where clause requirements. This is set only if the attribute
/// has a where clause.
Expand All @@ -1571,7 +1575,7 @@ class DifferentiableAttr final

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, AutoDiffIndexSubset *indices,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenericSignature);
Expand All @@ -1587,7 +1591,7 @@ class DifferentiableAttr final

static DifferentiableAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear, AutoDiffIndexSubset *indices,
bool linear, IndexSubset *indices,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
GenericSignature derivativeGenSig);
Expand All @@ -1602,10 +1606,10 @@ class DifferentiableAttr final
/// registered VJP.
Optional<DeclNameWithLoc> getVJP() const { return VJP; }

AutoDiffIndexSubset *getParameterIndices() const {
IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(AutoDiffIndexSubset *pi) {
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
}

Expand All @@ -1620,8 +1624,8 @@ class DifferentiableAttr final
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

bool isLinear() const { return linear; }
bool isLinear() const { return Linear; }

TrailingWhereClause *getWhereClause() const { return WhereClause; }

Expand Down Expand Up @@ -1650,10 +1654,10 @@ class DifferentiableAttr final

// Print the attribute to the given stream.
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
// If `omitAssociatedFunctions` is true, omit printing associated functions.
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
void print(llvm::raw_ostream &OS, const Decl *D,
bool omitWrtClause = false,
bool omitAssociatedFunctions = false) const;
bool omitDerivativeFunctions = false) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
Expand All @@ -1676,12 +1680,12 @@ class DifferentiatingAttr final
DeclNameWithLoc Original;
/// The original function, resolved by the type checker.
FuncDecl *OriginalFunction = nullptr;
/// Whether this function is linear (optional).
bool linear;
/// Whether this function is linear.
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
AutoDiffIndexSubset *ParameterIndices = nullptr;
IndexSubset *ParameterIndices = nullptr;

explicit DifferentiatingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
Expand All @@ -1691,7 +1695,7 @@ class DifferentiatingAttr final
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
AutoDiffIndexSubset *indices);
IndexSubset *indices);

public:
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
Expand All @@ -1702,11 +1706,11 @@ class DifferentiatingAttr final
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
AutoDiffIndexSubset *indices);
IndexSubset *indices);

DeclNameWithLoc getOriginal() const { return Original; }

bool isLinear() const { return linear; }
bool isLinear() const { return Linear; }

FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
Expand All @@ -1723,10 +1727,10 @@ class DifferentiatingAttr final
return NumParsedParameters;
}

AutoDiffIndexSubset *getParameterIndices() const {
IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(AutoDiffIndexSubset *pi) {
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
}

Expand Down Expand Up @@ -1757,7 +1761,7 @@ class TransposingAttr final
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
AutoDiffIndexSubset *ParameterIndexSubset = nullptr;
IndexSubset *ParameterIndexSubset = nullptr;

explicit TransposingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
Expand All @@ -1767,7 +1771,7 @@ class TransposingAttr final
explicit TransposingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
AutoDiffIndexSubset *indices);
IndexSubset *indices);

public:
static TransposingAttr *create(ASTContext &context, bool implicit,
Expand All @@ -1778,7 +1782,7 @@ class TransposingAttr final
static TransposingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
AutoDiffIndexSubset *indices);
IndexSubset *indices);

TypeRepr *getBaseType() const { return BaseType; }
DeclNameWithLoc getOriginal() const { return Original; }
Expand All @@ -1798,10 +1802,10 @@ class TransposingAttr final
return NumParsedParameters;
}

AutoDiffIndexSubset *getParameterIndexSubset() const {
IndexSubset *getParameterIndexSubset() const {
return ParameterIndexSubset;
}
void setParameterIndices(AutoDiffIndexSubset *pi) {
void setParameterIndices(IndexSubset *pi) {
ParameterIndexSubset = pi;
}

Expand Down
Loading