Skip to content

Commit 26c8f07

Browse files
authored
Merge pull request #27745 from apple/asuhan/tensorflow-merge
Merge branch 'tensorflow' into tensorflow-merge
2 parents cf1906c + 117cf39 commit 26c8f07

File tree

135 files changed

+5038
-2332
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

135 files changed

+5038
-2332
lines changed

docs/SIL.rst

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5607,40 +5607,64 @@ differentiable_function
56075607

56085608
sil-instruction ::= 'differentiable_function'
56095609
sil-differentiable-function-parameter-indices?
5610-
sil-differentiable-function-order?
56115610
sil-value ':' sil-type
5612-
sil-differentiable-function-associated-functions-clause?
5611+
sil-differentiable-function-derivative-functions-clause?
56135612
56145613
sil-differentiable-function-parameter-indices ::=
5615-
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
5616-
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
5617-
sil-differentiable-associated-functions-clause ::=
5618-
'with' sil-differentiable-associated-function-list
5619-
(',' sil-differentiable-associated-function-list)*
5620-
sil-differentiable-function-associated-function-list ::=
5621-
'{' sil-value ',' sil-value '}'
5622-
5623-
differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \
5624-
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
5625-
5626-
Bundles a function with its associated differentiation functions up to a
5627-
specified differentiation order into an ``@differentiable`` function. There are
5628-
two associated functions per differentiation order: a Jacobian-vector products
5629-
(JVP) function and a vector-Jacobian products (VJP) function.
5630-
5631-
``[wrt ...]`` specifies parameter indices that the original function is
5614+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
5615+
sil-differentiable-derivative-functions-clause ::=
5616+
'with_derivative'
5617+
'{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'
5618+
5619+
differentiable_function [parameters 0] %0 : $(T) -> T \
5620+
with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
5621+
5622+
Bundles a function with its derivative functions into a ``@differentiable``
5623+
function. There are two derivative functions: a Jacobian-vector products (JVP)
5624+
function and a vector-Jacobian products (VJP) function.
5625+
5626+
``[parameters ...]`` specifies parameter indices that the original function is
56325627
differentiable with respect to. When not specified, it defaults to all
56335628
parameters.
56345629

5635-
``[order ...]`` specifies the maximum differentiation order for the resulting
5636-
function. The number of lists of associated functions is equal to the order.
5630+
A ``with_derivative`` clause specifies the differentiation functions associated
5631+
with the original function. When a ``with_derivative`` clause is not specified,
5632+
the first operand will be differentiated to produce derivative functions, and a
5633+
``with_derivative`` clause will be added to the instruction.
5634+
5635+
In raw SIL, it is optional to provide a derivative function ``with_derivative``
5636+
clause. In canonical SIL, a ``with_derivative`` clause is mandatory.
5637+
5638+
5639+
linear_function
5640+
```````````````
5641+
5642+
::
5643+
5644+
sil-instruction ::= 'linear_function'
5645+
sil-linear-function-parameter-indices?
5646+
sil-value ':' sil-type
5647+
sil-linear-function-transpose-function-clause?
5648+
5649+
sil-linear-function-parameter-indices ::=
5650+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
5651+
sil-linear-transpose-function-clause ::=
5652+
with_transpose sil-value ':' sil-type
5653+
5654+
linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T
56375655

5638-
A ``with`` clause specifies the differentiation functions associated
5639-
with the original function. When a ``with`` clause is not specified, the first
5640-
operand will be differentiated to produce associated functions, and a ``with``
5641-
clause will be added to the instruction.
5656+
Bundles a function with its transpose function into a
5657+
``@differentiable(linear)`` function.
56425658

5643-
In raw SIL, it is optional to provide an associated function ``with`` clause.
5659+
``[parameters ...]`` specifies parameter indices that the original function is
5660+
linear with respect to. When not specified, it defaults to all parameters.
5661+
5662+
A ``with_transpose`` clause specifies the transpose function associated
5663+
with the original function. When a ``with_transpose`` clause is not specified,
5664+
the mandatory differentiation transform will add a ``with_transpose`` clause to
5665+
the instruction.
5666+
5667+
In raw SIL, it is optional to provide a transpose function ``with`` clause.
56445668
In canonical SIL, a ``with`` clause is mandatory.
56455669

56465670

@@ -5651,21 +5675,40 @@ differentiable_function_extract
56515675

56525676
sil-instruction ::= 'differentiable_function_extract'
56535677
sil-differentiable-function-extractee
5654-
sil-differentiable-function-order?
56555678
sil-value ':' sil-type
56565679

56575680
sil-differentiable-function-extractee ::=
56585681
'[' sil-differentiable-function-extractee ']'
56595682
sil-differentiable-function-extractee-name ::= 'original' | 'jvp' | 'vjp'
5660-
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'
56615683

56625684
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
5663-
differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T
5664-
differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T
5685+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
5686+
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
5687+
5688+
Extracts the original function or a derivative function from the given
5689+
``@differentiable`` function. It must be provided with an extractee:
5690+
``[original]``, ``[jvp]`` or ``[vjp]``.
5691+
5692+
5693+
linear_function_extract
5694+
```````````````````````
5695+
5696+
::
5697+
5698+
sil-instruction ::= 'linear_function_extract'
5699+
sil-linear-function-extractee
5700+
sil-value ':' sil-type
5701+
5702+
sil-linear-function-extractee ::=
5703+
'[' sil-linear-function-extractee ']'
5704+
sil-linear-function-extractee-name ::= 'original' | 'transpose'
5705+
5706+
linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
5707+
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
56655708

5666-
Extracts the original function or an associated function from the given
5667-
``@differentiable`` function at a specific differentiation order. It must be
5668-
provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``.
5709+
Extracts the original function or a transpose function from the given
5710+
``@differentiable(linear)`` function. It must be provided with an extractee:
5711+
``[original]`` or ``[transpose]``.
56695712

56705713

56715714
Assertion configuration

include/swift/AST/ASTContext.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ namespace swift {
109109
class TypeAliasDecl;
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112-
// SWIFT_ENABLE_TENSORFLOW
113-
class AutoDiffIndexSubset;
112+
class IndexSubset;
114113
class VectorSpace;
115114
class DifferentiableAttr;
116115

@@ -280,8 +279,7 @@ class ASTContext final {
280279
/// Cache of `@differentiable` attributes keyed by parameter indices. This
281280
/// helps us diagnose multiple `@differentiable`s that are with respect to the
282281
/// same set of parameters.
283-
llvm::DenseMap<std::pair<Decl *, AutoDiffIndexSubset *>,
284-
DifferentiableAttr *>
282+
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
285283
DifferentiableAttrs;
286284

287285
private:

include/swift/AST/ASTMangler.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,23 +155,31 @@ class ASTMangler : public Mangler {
155155
ModuleDecl *Module);
156156

157157
// SWIFT_ENABLE_TENSORFLOW
158-
// Mangle the autodiff associated function (JVP/VJP) with the given:
159-
// - Mangled original function name.
160-
// - Associated function kind.
161-
// - Parameter/result indices.
162-
std::string mangleAutoDiffAssociatedFunctionHelper(
163-
StringRef name, AutoDiffAssociatedFunctionKind kind,
158+
/// Mangle the derivative function (JVP/VJP) with the given:
159+
/// - Mangled original function name.
160+
/// - Derivative function kind.
161+
/// - Parameter/result indices.
162+
std::string mangleAutoDiffDerivativeFunctionHelper(
163+
StringRef name, AutoDiffDerivativeFunctionKind kind,
164164
const SILAutoDiffIndices &indices);
165165

166-
// SWIFT_ENABLE_TENSORFLOW
167-
// Mangle the autodiff linear map (differential/pullback) with the given:
168-
// - Mangled original function name.
169-
// - Linear map kind.
170-
// - Parameter/result indices.
166+
/// Mangle the autodiff linear map (differential/pullback) with the given:
167+
/// - Mangled original function name.
168+
/// - Linear map kind.
169+
/// - Parameter/result indices.
171170
std::string mangleAutoDiffLinearMapHelper(
172171
StringRef name, AutoDiffLinearMapKind kind,
173172
const SILAutoDiffIndices &indices);
174173

174+
/// Mangle a SIL differentiability witness key.
175+
/// - Mangled original function name.
176+
/// - Parameter indices.
177+
/// - Result indices.
178+
/// - Derivative generic signature (optional).
179+
std::string mangleSILDifferentiabilityWitnessKey(
180+
SILDifferentiabilityWitnessKey key);
181+
// SWIFT_ENABLE_TENSORFLOW END
182+
175183
std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
176184
GenericSignature signature,
177185
CanType baseType,

include/swift/AST/Attr.h

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ class TypeAttributes {
8989
bool isValid() const { return AtLoc.isValid(); }
9090

9191
// SWIFT_ENABLE_TENSORFLOW
92-
bool isLinear() const { return linear; }
93-
92+
bool isLinear() const {
93+
assert(!linear || (linear && has(TAK_differentiable)) &&
94+
"Linear shouldn't have been true if there's no `@differentiable`");
95+
return linear;
96+
}
97+
9498
void clearAttribute(TypeAttrKind A) {
9599
AttrLocs[A] = SourceLoc();
96100
}
@@ -1537,8 +1541,8 @@ class DifferentiableAttr final
15371541
ParsedAutoDiffParameter> {
15381542
friend TrailingObjects;
15391543

1540-
/// Whether this function is linear (optional).
1541-
bool linear;
1544+
/// Whether this function is linear.
1545+
bool Linear;
15421546
/// The number of parsed parameters specified in 'wrt:'.
15431547
unsigned NumParsedParameters = 0;
15441548
/// The JVP function.
@@ -1552,10 +1556,10 @@ class DifferentiableAttr final
15521556
/// specified.
15531557
FuncDecl *VJPFunction = nullptr;
15541558
/// The differentiation parameters' indices, resolved by the type checker.
1555-
AutoDiffIndexSubset *ParameterIndices = nullptr;
1559+
IndexSubset *ParameterIndices = nullptr;
15561560
/// The trailing where clause (optional).
15571561
TrailingWhereClause *WhereClause = nullptr;
1558-
/// The generic signature for autodiff associated functions. Resolved by the
1562+
/// The generic signature for autodiff derivative functions. Resolved by the
15591563
/// type checker based on the original function's generic signature and the
15601564
/// attribute's where clause requirements. This is set only if the attribute
15611565
/// has a where clause.
@@ -1571,7 +1575,7 @@ class DifferentiableAttr final
15711575

15721576
explicit DifferentiableAttr(ASTContext &context, bool implicit,
15731577
SourceLoc atLoc, SourceRange baseRange,
1574-
bool linear, AutoDiffIndexSubset *indices,
1578+
bool linear, IndexSubset *indices,
15751579
Optional<DeclNameWithLoc> jvp,
15761580
Optional<DeclNameWithLoc> vjp,
15771581
GenericSignature derivativeGenericSignature);
@@ -1587,7 +1591,7 @@ class DifferentiableAttr final
15871591

15881592
static DifferentiableAttr *create(ASTContext &context, bool implicit,
15891593
SourceLoc atLoc, SourceRange baseRange,
1590-
bool linear, AutoDiffIndexSubset *indices,
1594+
bool linear, IndexSubset *indices,
15911595
Optional<DeclNameWithLoc> jvp,
15921596
Optional<DeclNameWithLoc> vjp,
15931597
GenericSignature derivativeGenSig);
@@ -1602,10 +1606,10 @@ class DifferentiableAttr final
16021606
/// registered VJP.
16031607
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
16041608

1605-
AutoDiffIndexSubset *getParameterIndices() const {
1609+
IndexSubset *getParameterIndices() const {
16061610
return ParameterIndices;
16071611
}
1608-
void setParameterIndices(AutoDiffIndexSubset *pi) {
1612+
void setParameterIndices(IndexSubset *pi) {
16091613
ParameterIndices = pi;
16101614
}
16111615

@@ -1620,8 +1624,8 @@ class DifferentiableAttr final
16201624
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
16211625
return NumParsedParameters;
16221626
}
1623-
1624-
bool isLinear() const { return linear; }
1627+
1628+
bool isLinear() const { return Linear; }
16251629

16261630
TrailingWhereClause *getWhereClause() const { return WhereClause; }
16271631

@@ -1650,10 +1654,10 @@ class DifferentiableAttr final
16501654

16511655
// Print the attribute to the given stream.
16521656
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1653-
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1657+
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
16541658
void print(llvm::raw_ostream &OS, const Decl *D,
16551659
bool omitWrtClause = false,
1656-
bool omitAssociatedFunctions = false) const;
1660+
bool omitDerivativeFunctions = false) const;
16571661

16581662
static bool classof(const DeclAttribute *DA) {
16591663
return DA->getKind() == DAK_Differentiable;
@@ -1676,12 +1680,12 @@ class DifferentiatingAttr final
16761680
DeclNameWithLoc Original;
16771681
/// The original function, resolved by the type checker.
16781682
FuncDecl *OriginalFunction = nullptr;
1679-
/// Whether this function is linear (optional).
1680-
bool linear;
1683+
/// Whether this function is linear.
1684+
bool Linear;
16811685
/// The number of parsed parameters specified in 'wrt:'.
16821686
unsigned NumParsedParameters = 0;
16831687
/// The differentiation parameters' indices, resolved by the type checker.
1684-
AutoDiffIndexSubset *ParameterIndices = nullptr;
1688+
IndexSubset *ParameterIndices = nullptr;
16851689

16861690
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16871691
SourceLoc atLoc, SourceRange baseRange,
@@ -1691,7 +1695,7 @@ class DifferentiatingAttr final
16911695
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16921696
SourceLoc atLoc, SourceRange baseRange,
16931697
DeclNameWithLoc original, bool linear,
1694-
AutoDiffIndexSubset *indices);
1698+
IndexSubset *indices);
16951699

16961700
public:
16971701
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
@@ -1702,11 +1706,11 @@ class DifferentiatingAttr final
17021706
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
17031707
SourceLoc atLoc, SourceRange baseRange,
17041708
DeclNameWithLoc original, bool linear,
1705-
AutoDiffIndexSubset *indices);
1709+
IndexSubset *indices);
17061710

17071711
DeclNameWithLoc getOriginal() const { return Original; }
17081712

1709-
bool isLinear() const { return linear; }
1713+
bool isLinear() const { return Linear; }
17101714

17111715
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
17121716
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
@@ -1723,10 +1727,10 @@ class DifferentiatingAttr final
17231727
return NumParsedParameters;
17241728
}
17251729

1726-
AutoDiffIndexSubset *getParameterIndices() const {
1730+
IndexSubset *getParameterIndices() const {
17271731
return ParameterIndices;
17281732
}
1729-
void setParameterIndices(AutoDiffIndexSubset *pi) {
1733+
void setParameterIndices(IndexSubset *pi) {
17301734
ParameterIndices = pi;
17311735
}
17321736

@@ -1757,7 +1761,7 @@ class TransposingAttr final
17571761
/// The number of parsed parameters specified in 'wrt:'.
17581762
unsigned NumParsedParameters = 0;
17591763
/// The differentiation parameters' indices, resolved by the type checker.
1760-
AutoDiffIndexSubset *ParameterIndexSubset = nullptr;
1764+
IndexSubset *ParameterIndexSubset = nullptr;
17611765

17621766
explicit TransposingAttr(ASTContext &context, bool implicit,
17631767
SourceLoc atLoc, SourceRange baseRange,
@@ -1767,7 +1771,7 @@ class TransposingAttr final
17671771
explicit TransposingAttr(ASTContext &context, bool implicit,
17681772
SourceLoc atLoc, SourceRange baseRange,
17691773
TypeRepr *baseType, DeclNameWithLoc original,
1770-
AutoDiffIndexSubset *indices);
1774+
IndexSubset *indices);
17711775

17721776
public:
17731777
static TransposingAttr *create(ASTContext &context, bool implicit,
@@ -1778,7 +1782,7 @@ class TransposingAttr final
17781782
static TransposingAttr *create(ASTContext &context, bool implicit,
17791783
SourceLoc atLoc, SourceRange baseRange,
17801784
TypeRepr *baseType, DeclNameWithLoc original,
1781-
AutoDiffIndexSubset *indices);
1785+
IndexSubset *indices);
17821786

17831787
TypeRepr *getBaseType() const { return BaseType; }
17841788
DeclNameWithLoc getOriginal() const { return Original; }
@@ -1798,10 +1802,10 @@ class TransposingAttr final
17981802
return NumParsedParameters;
17991803
}
18001804

1801-
AutoDiffIndexSubset *getParameterIndexSubset() const {
1805+
IndexSubset *getParameterIndexSubset() const {
18021806
return ParameterIndexSubset;
18031807
}
1804-
void setParameterIndices(AutoDiffIndexSubset *pi) {
1808+
void setParameterIndices(IndexSubset *pi) {
18051809
ParameterIndexSubset = pi;
18061810
}
18071811

0 commit comments

Comments
 (0)