Skip to content

Commit ab304f1

Browse files
committed
Merge branch 'tensorflow' of github.com:apple/swift into sil-differentiability-witness
2 parents 6b78684 + 431cc43 commit ab304f1

File tree

95 files changed

+2040
-1864
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

95 files changed

+2040
-1864
lines changed

docs/SIL.rst

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5609,38 +5609,34 @@ differentiable_function
56095609
sil-differentiable-function-parameter-indices?
56105610
sil-differentiable-function-order?
56115611
sil-value ':' sil-type
5612-
sil-differentiable-function-associated-functions-clause?
5612+
sil-differentiable-function-derivative-functions-clause?
56135613
56145614
sil-differentiable-function-parameter-indices ::=
56155615
'[' 'wrt' [0-9]+ (',', [0-9]+)* ']'
56165616
sil-differentiable-function-order ::= '[' 'order' [0-9]+ ']'
5617-
sil-differentiable-associated-functions-clause ::=
5618-
'with' sil-differentiable-associated-function-list
5619-
(',' sil-differentiable-associated-function-list)*
5620-
sil-differentiable-function-associated-function-list ::=
5617+
sil-differentiable-derivative-functions-clause ::=
5618+
'with' sil-differentiable-derivative-function-list
5619+
(',' sil-differentiable-derivative-function-list)*
5620+
sil-differentiable-function-derivative-function-list ::=
56215621
'{' sil-value ',' sil-value '}'
56225622

5623-
differentiable_function [wrt 0] [order 1] %0 : $(T) -> T \
5623+
differentiable_function [wrt 0] %0 : $(T) -> T \
56245624
with {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
56255625

5626-
Bundles a function with its associated differentiation functions up to a
5627-
specified differentiation order into an ``@differentiable`` function. There are
5628-
two associated functions per differentiation order: a Jacobian-vector products
5629-
(JVP) function and a vector-Jacobian products (VJP) function.
5626+
Bundles a function with its derivative functions into a ``@differentiable``
5627+
function. There are two derivative functions: a Jacobian-vector products (JVP)
5628+
function and a vector-Jacobian products (VJP) function.
56305629

56315630
``[wrt ...]`` specifies parameter indices that the original function is
56325631
differentiable with respect to. When not specified, it defaults to all
56335632
parameters.
56345633

5635-
``[order ...]`` specifies the maximum differentiation order for the resulting
5636-
function. The number of lists of associated functions is equal to the order.
5637-
56385634
A ``with`` clause specifies the differentiation functions associated
56395635
with the original function. When a ``with`` clause is not specified, the first
5640-
operand will be differentiated to produce associated functions, and a ``with``
5636+
operand will be differentiated to produce derivative functions, and a ``with``
56415637
clause will be added to the instruction.
56425638

5643-
In raw SIL, it is optional to provide an associated function ``with`` clause.
5639+
In raw SIL, it is optional to provide a derivative function ``with`` clause.
56445640
In canonical SIL, a ``with`` clause is mandatory.
56455641

56465642

@@ -5660,12 +5656,12 @@ differentiable_function_extract
56605656
sil-differentiable-function-differentiation-order ::= '[' 'order' [0-9]+ ']'
56615657

56625658
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
5663-
differentiable_function_extract [jvp] [order 1] %0 : $@differentiable (T) -> T
5664-
differentiable_function_extract [vjp] [order 1] %0 : $@differentiable (T) -> T
5659+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
5660+
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
56655661

5666-
Extracts the original function or an associated function from the given
5667-
``@differentiable`` function at a specific differentiation order. It must be
5668-
provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``.
5662+
Extracts the original function or a derivative function from the given
5663+
``@differentiable`` function. It must be provided with an extractee:
5664+
``[original]``, ``[jvp]`` or ``[vjp]``.
56695665

56705666

56715667
Assertion configuration

include/swift/AST/ASTContext.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ namespace swift {
110110
class VarDecl;
111111
class UnifiedStatsReporter;
112112
// SWIFT_ENABLE_TENSORFLOW
113+
class AutoDiffIndexSubset;
113114
class VectorSpace;
114-
class AutoDiffParameterIndices;
115115
class DifferentiableAttr;
116116

117117
enum class KnownProtocolKind : uint8_t;
@@ -280,8 +280,9 @@ class ASTContext final {
280280
/// Cache of `@differentiable` attributes keyed by parameter indices. This
281281
/// helps us diagnose multiple `@differentiable`s that are with respect to the
282282
/// same set of parameters.
283-
llvm::DenseMap<std::pair<Decl *, AutoDiffParameterIndices *>,
284-
DifferentiableAttr *> DifferentiableAttrs;
283+
llvm::DenseMap<std::pair<Decl *, AutoDiffIndexSubset *>,
284+
DifferentiableAttr *>
285+
DifferentiableAttrs;
285286

286287
private:
287288
/// The current generation number, which reflects the number of

include/swift/AST/ASTMangler.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ class ASTMangler : public Mangler {
155155
ModuleDecl *Module);
156156

157157
// SWIFT_ENABLE_TENSORFLOW
158-
// Mangle the autodiff associated function (JVP/VJP) with the given:
158+
// Mangle the derivative function (JVP/VJP) with the given:
159159
// - Mangled original function name.
160-
// - Associated function kind.
160+
// - Derivative function kind.
161161
// - Parameter/result indices.
162-
std::string mangleAutoDiffAssociatedFunctionHelper(
163-
StringRef name, AutoDiffAssociatedFunctionKind kind,
162+
std::string mangleAutoDiffDerivativeFunctionHelper(
163+
StringRef name, AutoDiffDerivativeFunctionKind kind,
164164
const SILAutoDiffIndices &indices);
165165

166166
// SWIFT_ENABLE_TENSORFLOW

include/swift/AST/Attr.h

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,10 +1552,10 @@ class DifferentiableAttr final
15521552
/// specified.
15531553
FuncDecl *VJPFunction = nullptr;
15541554
/// The differentiation parameters' indices, resolved by the type checker.
1555-
AutoDiffParameterIndices *ParameterIndices = nullptr;
1555+
AutoDiffIndexSubset *ParameterIndices = nullptr;
15561556
/// The trailing where clause (optional).
15571557
TrailingWhereClause *WhereClause = nullptr;
1558-
/// The generic signature for autodiff associated functions. Resolved by the
1558+
/// The generic signature for autodiff derivative functions. Resolved by the
15591559
/// type checker based on the original function's generic signature and the
15601560
/// attribute's where clause requirements. This is set only if the attribute
15611561
/// has a where clause.
@@ -1571,8 +1571,7 @@ class DifferentiableAttr final
15711571

15721572
explicit DifferentiableAttr(ASTContext &context, bool implicit,
15731573
SourceLoc atLoc, SourceRange baseRange,
1574-
bool linear,
1575-
AutoDiffParameterIndices *indices,
1574+
bool linear, AutoDiffIndexSubset *indices,
15761575
Optional<DeclNameWithLoc> jvp,
15771576
Optional<DeclNameWithLoc> vjp,
15781577
GenericSignature *derivativeGenericSignature);
@@ -1588,8 +1587,7 @@ class DifferentiableAttr final
15881587

15891588
static DifferentiableAttr *create(ASTContext &context, bool implicit,
15901589
SourceLoc atLoc, SourceRange baseRange,
1591-
bool linear,
1592-
AutoDiffParameterIndices *indices,
1590+
bool linear, AutoDiffIndexSubset *indices,
15931591
Optional<DeclNameWithLoc> jvp,
15941592
Optional<DeclNameWithLoc> vjp,
15951593
GenericSignature *derivativeGenSig);
@@ -1604,10 +1602,10 @@ class DifferentiableAttr final
16041602
/// registered VJP.
16051603
Optional<DeclNameWithLoc> getVJP() const { return VJP; }
16061604

1607-
AutoDiffParameterIndices *getParameterIndices() const {
1605+
AutoDiffIndexSubset *getParameterIndices() const {
16081606
return ParameterIndices;
16091607
}
1610-
void setParameterIndices(AutoDiffParameterIndices *pi) {
1608+
void setParameterIndices(AutoDiffIndexSubset *pi) {
16111609
ParameterIndices = pi;
16121610
}
16131611

@@ -1642,7 +1640,7 @@ class DifferentiableAttr final
16421640

16431641
bool parametersMatch(const DifferentiableAttr &other) const {
16441642
assert(ParameterIndices && other.ParameterIndices);
1645-
return ParameterIndices->parameters == other.ParameterIndices->parameters;
1643+
return ParameterIndices == other.ParameterIndices;
16461644
}
16471645

16481646
/// Get the derivative generic environment for the given `@differentiable`
@@ -1652,10 +1650,10 @@ class DifferentiableAttr final
16521650

16531651
// Print the attribute to the given stream.
16541652
// If `omitWrtClause` is true, omit printing the `wrt:` clause.
1655-
// If `omitAssociatedFunctions` is true, omit printing associated functions.
1653+
// If `omitDerivativeFunctions` is true, omit printing derivative functions.
16561654
void print(llvm::raw_ostream &OS, const Decl *D,
16571655
bool omitWrtClause = false,
1658-
bool omitAssociatedFunctions = false) const;
1656+
bool omitDerivativeFunctions = false) const;
16591657

16601658
static bool classof(const DeclAttribute *DA) {
16611659
return DA->getKind() == DAK_Differentiable;
@@ -1683,7 +1681,7 @@ class DifferentiatingAttr final
16831681
/// The number of parsed parameters specified in 'wrt:'.
16841682
unsigned NumParsedParameters = 0;
16851683
/// The differentiation parameters' indices, resolved by the type checker.
1686-
AutoDiffParameterIndices *ParameterIndices = nullptr;
1684+
AutoDiffIndexSubset *ParameterIndices = nullptr;
16871685

16881686
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16891687
SourceLoc atLoc, SourceRange baseRange,
@@ -1693,7 +1691,7 @@ class DifferentiatingAttr final
16931691
explicit DifferentiatingAttr(ASTContext &context, bool implicit,
16941692
SourceLoc atLoc, SourceRange baseRange,
16951693
DeclNameWithLoc original, bool linear,
1696-
AutoDiffParameterIndices *indices);
1694+
AutoDiffIndexSubset *indices);
16971695

16981696
public:
16991697
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
@@ -1704,7 +1702,7 @@ class DifferentiatingAttr final
17041702
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
17051703
SourceLoc atLoc, SourceRange baseRange,
17061704
DeclNameWithLoc original, bool linear,
1707-
AutoDiffParameterIndices *indices);
1705+
AutoDiffIndexSubset *indices);
17081706

17091707
DeclNameWithLoc getOriginal() const { return Original; }
17101708

@@ -1725,10 +1723,10 @@ class DifferentiatingAttr final
17251723
return NumParsedParameters;
17261724
}
17271725

1728-
AutoDiffParameterIndices *getParameterIndices() const {
1726+
AutoDiffIndexSubset *getParameterIndices() const {
17291727
return ParameterIndices;
17301728
}
1731-
void setParameterIndices(AutoDiffParameterIndices *pi) {
1729+
void setParameterIndices(AutoDiffIndexSubset *pi) {
17321730
ParameterIndices = pi;
17331731
}
17341732

0 commit comments

Comments
 (0)