@@ -1709,7 +1709,15 @@ class DifferentiableAttr final
1709
1709
private llvm::TrailingObjects<DifferentiableAttr,
1710
1710
ParsedAutoDiffParameter> {
1711
1711
friend TrailingObjects;
1712
-
1712
+ // SWIFT_ENABLE_TENSORFLOW
1713
+ friend class DifferentiableAttributeParameterIndicesRequest ;
1714
+ // SWIFT_ENABLE_TENSORFLOW END
1715
+
1716
+ // SWIFT_ENABLE_TENSORFLOW
1717
+ // / The declaration on which the `@differentiable` attribute is declared.
1718
+ // / Resolved during parsing and deserialization.
1719
+ Decl *OriginalDeclaration = nullptr ;
1720
+ // SWIFT_ENABLE_TENSORFLOW END
1713
1721
// / Whether this function is linear (optional).
1714
1722
bool Linear;
1715
1723
// / The number of parsed parameters specified in 'wrt:'.
@@ -1724,8 +1732,13 @@ class DifferentiableAttr final
1724
1732
// / The VJP function (optional), resolved by the type checker if VJP name is
1725
1733
// / specified.
1726
1734
FuncDecl *VJPFunction = nullptr ;
1735
+ // SWIFT_ENABLE_TENSORFLOW
1736
+ // NOTE: Parameter indices requestification is done on `tensorflow` branch but
1737
+ // has not yet been upstreamed to `master` branch.
1727
1738
// / The differentiation parameters' indices, resolved by the type checker.
1728
- IndexSubset *ParameterIndices = nullptr ;
1739
+ // / The bit stores whether the parameter indices have been computed.
1740
+ llvm::PointerIntPair<IndexSubset *, 1 , bool > ParameterIndicesAndBit;
1741
+ // SWIFT_ENABLE_TENSORFLOW END
1729
1742
// / The trailing where clause (optional).
1730
1743
TrailingWhereClause *WhereClause = nullptr ;
1731
1744
// / The generic signature for autodiff associated functions. Resolved by the
@@ -1765,6 +1778,9 @@ class DifferentiableAttr final
1765
1778
Optional<DeclNameRefWithLoc> vjp,
1766
1779
GenericSignature derivativeGenSig);
1767
1780
1781
+ Decl *getOriginalDeclaration () const { return OriginalDeclaration; }
1782
+ void setOriginalDeclaration (Decl *decl);
1783
+
1768
1784
// / Get the optional 'jvp:' function name and location.
1769
1785
// / Use this instead of `getJVPFunction` to check whether the attribute has a
1770
1786
// / registered JVP.
@@ -1775,12 +1791,13 @@ class DifferentiableAttr final
1775
1791
// / registered VJP.
1776
1792
Optional<DeclNameRefWithLoc> getVJP () const { return VJP; }
1777
1793
1778
- IndexSubset *getParameterIndices () const {
1779
- return ParameterIndices;
1780
- }
1781
- void setParameterIndices (IndexSubset *parameterIndices) {
1782
- ParameterIndices = parameterIndices;
1783
- }
1794
+ // SWIFT_ENABLE_TENSORFLOW
1795
+ // NOTE: Parameter indices requestification is done on `tensorflow` branch but
1796
+ // has not yet been upstreamed to `master` branch.
1797
+ bool hasComputedParameterIndices () const ;
1798
+ IndexSubset *getParameterIndices () const ;
1799
+ void setParameterIndices (IndexSubset *paramIndices);
1800
+ // SWIFT_ENABLE_TENSORFLOW END
1784
1801
1785
1802
// / The parsed differentiation parameters, i.e. the list of parameters
1786
1803
// / specified in 'wrt:'.
@@ -1920,6 +1937,89 @@ class DerivativeAttr final
1920
1937
}
1921
1938
};
1922
1939
1940
+ // SWIFT_ENABLE_TENSORFLOW
1941
+ // TODO(TF-999): Remove deprecated `@differentiating` attribute.
1942
+ using DifferentiatingAttr = DerivativeAttr;
1943
+
1944
+ // / Attribute that registers a function as a transpose of another function.
1945
+ // /
1946
+ // / Examples:
1947
+ // / @transpose(of: foo)
1948
+ // / @transpose(of: +, wrt: (lhs, rhs))
1949
+ class TransposeAttr final
1950
+ : public DeclAttribute,
1951
+ private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
1952
+ friend TrailingObjects;
1953
+
1954
+ // / The base type for the referenced original declaration. This field is
1955
+ // / non-null only for parsed attributes that reference a qualified original
1956
+ // / declaration. This field is not serialized; type-checking uses it to
1957
+ // / resolve the original declaration, which is serialized.
1958
+ TypeRepr *BaseTypeRepr;
1959
+ // / The original function name.
1960
+ DeclNameRefWithLoc OriginalFunctionName;
1961
+ // / The original function declaration, resolved by the type checker.
1962
+ AbstractFunctionDecl *OriginalFunction = nullptr ;
1963
+ // / The number of parsed parameters specified in 'wrt:'.
1964
+ unsigned NumParsedParameters = 0 ;
1965
+ // / The transposed parameters' indices, resolved by the type checker.
1966
+ IndexSubset *ParameterIndices = nullptr ;
1967
+
1968
+ explicit TransposeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
1969
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1970
+ ArrayRef<ParsedAutoDiffParameter> params);
1971
+
1972
+ explicit TransposeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
1973
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1974
+ IndexSubset *indices);
1975
+
1976
+ public:
1977
+ static TransposeAttr *create (ASTContext &context, bool implicit,
1978
+ SourceLoc atLoc, SourceRange baseRange,
1979
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1980
+ ArrayRef<ParsedAutoDiffParameter> params);
1981
+
1982
+ static TransposeAttr *create (ASTContext &context, bool implicit,
1983
+ SourceLoc atLoc, SourceRange baseRange,
1984
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1985
+ IndexSubset *indices);
1986
+
1987
+ TypeRepr *getBaseTypeRepr () const { return BaseTypeRepr; }
1988
+ DeclNameRefWithLoc getOriginalFunctionName () const {
1989
+ return OriginalFunctionName;
1990
+ }
1991
+ AbstractFunctionDecl *getOriginalFunction () const {
1992
+ return OriginalFunction;
1993
+ }
1994
+ void setOriginalFunction (AbstractFunctionDecl *decl) {
1995
+ OriginalFunction = decl;
1996
+ }
1997
+
1998
+ // / The parsed transposed parameters, i.e. the list of parameters specified in
1999
+ // / 'wrt:'.
2000
+ ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
2001
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2002
+ }
2003
+ MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
2004
+ return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2005
+ }
2006
+ size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
2007
+ return NumParsedParameters;
2008
+ }
2009
+
2010
+ IndexSubset *getParameterIndices () const {
2011
+ return ParameterIndices;
2012
+ }
2013
+ void setParameterIndices (IndexSubset *parameterIndices) {
2014
+ ParameterIndices = parameterIndices;
2015
+ }
2016
+
2017
+ static bool classof (const DeclAttribute *DA) {
2018
+ return DA->getKind () == DAK_Transpose;
2019
+ }
2020
+ };
2021
+ // SWIFT_ENABLE_TENSORFLOW END
2022
+
1923
2023
// / Attributes that may be applied to declarations.
1924
2024
class DeclAttributes {
1925
2025
// / Linked list of declaration attributes.
@@ -2099,87 +2199,6 @@ class DeclAttributes {
2099
2199
SourceLoc getStartLoc (bool forModifiers = false ) const ;
2100
2200
};
2101
2201
2102
- // TODO(TF-999): Remove deprecated `@differentiating` attribute.
2103
- using DifferentiatingAttr = DerivativeAttr;
2104
-
2105
- // / Attribute that registers a function as a transpose of another function.
2106
- // /
2107
- // / Examples:
2108
- // / @transpose(of: foo)
2109
- // / @transpose(of: +, wrt: (lhs, rhs))
2110
- class TransposeAttr final
2111
- : public DeclAttribute,
2112
- private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
2113
- friend TrailingObjects;
2114
-
2115
- // / The base type for the referenced original declaration. This field is
2116
- // / non-null only for parsed attributes that reference a qualified original
2117
- // / declaration. This field is not serialized; type-checking uses it to
2118
- // / resolve the original declaration, which is serialized.
2119
- TypeRepr *BaseTypeRepr;
2120
- // / The original function name.
2121
- DeclNameRefWithLoc OriginalFunctionName;
2122
- // / The original function declaration, resolved by the type checker.
2123
- AbstractFunctionDecl *OriginalFunction = nullptr ;
2124
- // / The number of parsed parameters specified in 'wrt:'.
2125
- unsigned NumParsedParameters = 0 ;
2126
- // / The transposed parameters' indices, resolved by the type checker.
2127
- IndexSubset *ParameterIndices = nullptr ;
2128
-
2129
- explicit TransposeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
2130
- TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2131
- ArrayRef<ParsedAutoDiffParameter> params);
2132
-
2133
- explicit TransposeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
2134
- TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2135
- IndexSubset *indices);
2136
-
2137
- public:
2138
- static TransposeAttr *create (ASTContext &context, bool implicit,
2139
- SourceLoc atLoc, SourceRange baseRange,
2140
- TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2141
- ArrayRef<ParsedAutoDiffParameter> params);
2142
-
2143
- static TransposeAttr *create (ASTContext &context, bool implicit,
2144
- SourceLoc atLoc, SourceRange baseRange,
2145
- TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2146
- IndexSubset *indices);
2147
-
2148
- TypeRepr *getBaseTypeRepr () const { return BaseTypeRepr; }
2149
- DeclNameRefWithLoc getOriginalFunctionName () const {
2150
- return OriginalFunctionName;
2151
- }
2152
- AbstractFunctionDecl *getOriginalFunction () const {
2153
- return OriginalFunction;
2154
- }
2155
- void setOriginalFunction (AbstractFunctionDecl *decl) {
2156
- OriginalFunction = decl;
2157
- }
2158
-
2159
- // / The parsed transposed parameters, i.e. the list of parameters specified in
2160
- // / 'wrt:'.
2161
- ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
2162
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2163
- }
2164
- MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
2165
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2166
- }
2167
- size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
2168
- return NumParsedParameters;
2169
- }
2170
-
2171
- IndexSubset *getParameterIndices () const {
2172
- return ParameterIndices;
2173
- }
2174
- void setParameterIndices (IndexSubset *parameterIndices) {
2175
- ParameterIndices = parameterIndices;
2176
- }
2177
-
2178
- static bool classof (const DeclAttribute *DA) {
2179
- return DA->getKind () == DAK_Transpose;
2180
- }
2181
- };
2182
-
2183
2202
void simple_display (llvm::raw_ostream &out, const DeclAttribute *attr);
2184
2203
2185
2204
inline SourceLoc extractNearestSourceLoc (const DeclAttribute *attr) {
0 commit comments