@@ -1631,7 +1631,6 @@ class OriginallyDefinedInAttr: public DeclAttribute {
1631
1631
}
1632
1632
};
1633
1633
1634
- <<<<<<< HEAD
1635
1634
// / Attribute that asks the compiler to generate a function that returns a
1636
1635
// / quoted representation of the attributed declaration.
1637
1636
// /
@@ -1691,7 +1690,8 @@ class QuotedAttr final : public DeclAttribute {
1691
1690
1692
1691
static QuotedAttr *create (ASTContext &context, FuncDecl *quoteDecl,
1693
1692
SourceLoc atLoc, SourceRange range, bool implicit);
1694
- =======
1693
+ };
1694
+
1695
1695
// / A declaration name with location.
1696
1696
struct DeclNameRefWithLoc {
1697
1697
DeclNameRef Name;
@@ -1918,7 +1918,6 @@ class DerivativeAttr final
1918
1918
static bool classof (const DeclAttribute *DA) {
1919
1919
return DA->getKind () == DAK_Derivative;
1920
1920
}
1921
- >>>>>>> upstream_20191216
1922
1921
};
1923
1922
1924
1923
// / Attributes that may be applied to declarations.
@@ -2100,224 +2099,6 @@ class DeclAttributes {
2100
2099
SourceLoc getStartLoc (bool forModifiers = false ) const ;
2101
2100
};
2102
2101
2103
- <<<<<<< HEAD
2104
- struct DeclNameWithLoc {
2105
- DeclName Name;
2106
- DeclNameLoc Loc;
2107
- };
2108
-
2109
- // / Attribute that marks a function as differentiable and optionally specifies
2110
- // / custom associated derivative functions: 'jvp' and 'vjp'.
2111
- // /
2112
- // / Examples:
2113
- // / @differentiable(jvp: jvpFoo where T : FloatingPoint)
2114
- // / @differentiable(wrt: (self, x, y), jvp: jvpFoo)
2115
- class DifferentiableAttr final
2116
- : public DeclAttribute,
2117
- private llvm::TrailingObjects<DifferentiableAttr,
2118
- ParsedAutoDiffParameter> {
2119
- friend TrailingObjects;
2120
- friend class DifferentiableAttributeParameterIndicesRequest ;
2121
-
2122
- // / The declaration on which the `@differentiable` attribute is declared.
2123
- Decl *OriginalDeclaration = nullptr ;
2124
- // / Whether this function is linear.
2125
- bool Linear;
2126
- // / The number of parsed parameters specified in 'wrt:'.
2127
- unsigned NumParsedParameters = 0 ;
2128
- // / The JVP function.
2129
- Optional<DeclNameWithLoc> JVP;
2130
- // / The VJP function.
2131
- Optional<DeclNameWithLoc> VJP;
2132
- // / The JVP function (optional), resolved by the type checker if JVP name is
2133
- // / specified.
2134
- FuncDecl *JVPFunction = nullptr ;
2135
- // / The VJP function (optional), resolved by the type checker if VJP name is
2136
- // / specified.
2137
- FuncDecl *VJPFunction = nullptr ;
2138
- // / The differentiation parameters' indices, resolved by the type checker.
2139
- // / The bit stores whether the parameter indices have been computed.
2140
- llvm::PointerIntPair<IndexSubset *, 1 , bool > ParameterIndicesAndBit;
2141
- // / The trailing where clause (optional).
2142
- TrailingWhereClause *WhereClause = nullptr ;
2143
- // / The generic signature for autodiff derivative functions. Resolved by the
2144
- // / type checker based on the original function's generic signature and the
2145
- // / attribute's where clause requirements. This is set only if the attribute
2146
- // / has a where clause.
2147
- GenericSignature DerivativeGenericSignature = GenericSignature();
2148
-
2149
- explicit DifferentiableAttr (bool implicit, SourceLoc atLoc,
2150
- SourceRange baseRange, bool linear,
2151
- ArrayRef<ParsedAutoDiffParameter> parameters,
2152
- Optional<DeclNameWithLoc> jvp,
2153
- Optional<DeclNameWithLoc> vjp,
2154
- TrailingWhereClause *clause);
2155
-
2156
- explicit DifferentiableAttr (Decl *original, bool implicit, SourceLoc atLoc,
2157
- SourceRange baseRange, bool linear,
2158
- IndexSubset *parameterIndices,
2159
- Optional<DeclNameWithLoc> jvp,
2160
- Optional<DeclNameWithLoc> vjp,
2161
- GenericSignature derivativeGenericSignature);
2162
-
2163
- public:
2164
- static DifferentiableAttr *create (ASTContext &context, bool implicit,
2165
- SourceLoc atLoc, SourceRange baseRange,
2166
- bool linear,
2167
- ArrayRef<ParsedAutoDiffParameter> params,
2168
- Optional<DeclNameWithLoc> jvp,
2169
- Optional<DeclNameWithLoc> vjp,
2170
- TrailingWhereClause *clause);
2171
-
2172
- static DifferentiableAttr *create (AbstractFunctionDecl *original,
2173
- bool implicit, SourceLoc atLoc,
2174
- SourceRange baseRange, bool linear,
2175
- IndexSubset *parameterIndices,
2176
- Optional<DeclNameWithLoc> jvp,
2177
- Optional<DeclNameWithLoc> vjp,
2178
- GenericSignature derivativeGenSig);
2179
-
2180
- Decl *getOriginalDeclaration () const { return OriginalDeclaration; }
2181
- void setOriginalDeclaration (Decl *decl);
2182
-
2183
- // / Get the optional 'jvp:' function name and location.
2184
- // / Use this instead of `getJVPFunction` to check whether the attribute has a
2185
- // / registered JVP.
2186
- Optional<DeclNameWithLoc> getJVP () const { return JVP; }
2187
-
2188
- // / Get the optional 'vjp:' function name and location.
2189
- // / Use this instead of `getVJPFunction` to check whether the attribute has a
2190
- // / registered VJP.
2191
- Optional<DeclNameWithLoc> getVJP () const { return VJP; }
2192
-
2193
- bool hasComputedParameterIndices () const ;
2194
- IndexSubset *getParameterIndices () const ;
2195
- void setParameterIndices (IndexSubset *paramIndices);
2196
-
2197
- // / The parsed differentiation parameters, i.e. the list of parameters
2198
- // / specified in 'wrt:'.
2199
- ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
2200
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2201
- }
2202
- MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
2203
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2204
- }
2205
- size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
2206
- return NumParsedParameters;
2207
- }
2208
-
2209
- bool isLinear () const { return Linear; }
2210
-
2211
- TrailingWhereClause *getWhereClause () const { return WhereClause; }
2212
-
2213
- GenericSignature getDerivativeGenericSignature () const {
2214
- return DerivativeGenericSignature;
2215
- }
2216
- void setDerivativeGenericSignature (GenericSignature derivativeGenSig) {
2217
- DerivativeGenericSignature = derivativeGenSig;
2218
- }
2219
-
2220
- FuncDecl *getJVPFunction () const { return JVPFunction; }
2221
- void setJVPFunction (FuncDecl *decl);
2222
- FuncDecl *getVJPFunction () const { return VJPFunction; }
2223
- void setVJPFunction (FuncDecl *decl);
2224
-
2225
- // / Get the derivative generic environment for the given `@differentiable`
2226
- // / attribute and original function.
2227
- GenericEnvironment *
2228
- getDerivativeGenericEnvironment (AbstractFunctionDecl *original) const ;
2229
-
2230
- // Print the attribute to the given stream.
2231
- // If `omitWrtClause` is true, omit printing the `wrt:` clause.
2232
- // If `omitDerivativeFunctions` is true, omit printing derivative functions.
2233
- void print (llvm::raw_ostream &OS, const Decl *D,
2234
- bool omitWrtClause = false ,
2235
- bool omitDerivativeFunctions = false ) const ;
2236
-
2237
- static bool classof (const DeclAttribute *DA) {
2238
- return DA->getKind () == DAK_Differentiable;
2239
- }
2240
- };
2241
-
2242
- // SWIFT_ENABLE_TENSORFLOW
2243
- // / Attribute that registers a function as a derivative of another function.
2244
- // /
2245
- // / Examples:
2246
- // / @derivative(of: sin(_:))
2247
- // / @derivative(of: +, wrt: (lhs, rhs))
2248
- class DerivativeAttr final
2249
- : public DeclAttribute,
2250
- private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
2251
- friend TrailingObjects;
2252
-
2253
- // / The original function name.
2254
- DeclNameWithLoc OriginalFunctionName;
2255
- // / The original function declaration, resolved by the type checker.
2256
- AbstractFunctionDecl *OriginalFunction = nullptr ;
2257
- // / The number of parsed parameters specified in 'wrt:'.
2258
- unsigned NumParsedParameters = 0 ;
2259
- // / The differentiation parameters' indices, resolved by the type checker.
2260
- IndexSubset *ParameterIndices = nullptr ;
2261
- // / The derivative function kind (JVP or VJP), resolved by the type checker.
2262
- Optional<AutoDiffDerivativeFunctionKind> Kind = None;
2263
-
2264
- explicit DerivativeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
2265
- DeclNameWithLoc original,
2266
- ArrayRef<ParsedAutoDiffParameter> params);
2267
-
2268
- explicit DerivativeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
2269
- DeclNameWithLoc original, IndexSubset *indices);
2270
-
2271
- public:
2272
- static DerivativeAttr *create (ASTContext &context, bool implicit,
2273
- SourceLoc atLoc, SourceRange baseRange,
2274
- DeclNameWithLoc original,
2275
- ArrayRef<ParsedAutoDiffParameter> params);
2276
-
2277
- static DerivativeAttr *create (ASTContext &context, bool implicit,
2278
- SourceLoc atLoc, SourceRange baseRange,
2279
- DeclNameWithLoc original, IndexSubset *indices);
2280
-
2281
- DeclNameWithLoc getOriginalFunctionName () const {
2282
- return OriginalFunctionName;
2283
- }
2284
- AbstractFunctionDecl *getOriginalFunction () const {
2285
- return OriginalFunction;
2286
- }
2287
- void setOriginalFunction (AbstractFunctionDecl *decl) {
2288
- OriginalFunction = decl;
2289
- }
2290
-
2291
- AutoDiffDerivativeFunctionKind getDerivativeKind () const {
2292
- assert (Kind && " Derivative function kind has not yet been resolved" );
2293
- return *Kind;
2294
- }
2295
- void setDerivativeKind (AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
2296
-
2297
- // / The parsed differentiation parameters, i.e. the list of parameters
2298
- // / specified in 'wrt:'.
2299
- ArrayRef<ParsedAutoDiffParameter> getParsedParameters () const {
2300
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2301
- }
2302
- MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters () {
2303
- return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2304
- }
2305
- size_t numTrailingObjects (OverloadToken<ParsedAutoDiffParameter>) const {
2306
- return NumParsedParameters;
2307
- }
2308
-
2309
- IndexSubset *getParameterIndices () const {
2310
- return ParameterIndices;
2311
- }
2312
- void setParameterIndices (IndexSubset *parameterIndices) {
2313
- ParameterIndices = parameterIndices;
2314
- }
2315
-
2316
- static bool classof (const DeclAttribute *DA) {
2317
- return DA->getKind () == DAK_Derivative;
2318
- }
2319
- };
2320
-
2321
2102
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
2322
2103
using DifferentiatingAttr = DerivativeAttr;
2323
2104
@@ -2337,7 +2118,7 @@ class TransposeAttr final
2337
2118
// / resolve the original declaration, which is serialized.
2338
2119
TypeRepr *BaseTypeRepr;
2339
2120
// / The original function name.
2340
- DeclNameWithLoc OriginalFunctionName;
2121
+ DeclNameRefWithLoc OriginalFunctionName;
2341
2122
// / The original function declaration, resolved by the type checker.
2342
2123
AbstractFunctionDecl *OriginalFunction = nullptr ;
2343
2124
// / The number of parsed parameters specified in 'wrt:'.
@@ -2346,26 +2127,26 @@ class TransposeAttr final
2346
2127
IndexSubset *ParameterIndices = nullptr ;
2347
2128
2348
2129
explicit TransposeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
2349
- TypeRepr *baseTypeRepr, DeclNameWithLoc original,
2130
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2350
2131
ArrayRef<ParsedAutoDiffParameter> params);
2351
2132
2352
2133
explicit TransposeAttr (bool implicit, SourceLoc atLoc, SourceRange baseRange,
2353
- TypeRepr *baseTypeRepr, DeclNameWithLoc original,
2134
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2354
2135
IndexSubset *indices);
2355
2136
2356
2137
public:
2357
2138
static TransposeAttr *create (ASTContext &context, bool implicit,
2358
2139
SourceLoc atLoc, SourceRange baseRange,
2359
- TypeRepr *baseTypeRepr, DeclNameWithLoc original,
2140
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2360
2141
ArrayRef<ParsedAutoDiffParameter> params);
2361
2142
2362
2143
static TransposeAttr *create (ASTContext &context, bool implicit,
2363
2144
SourceLoc atLoc, SourceRange baseRange,
2364
- TypeRepr *baseTypeRepr, DeclNameWithLoc original,
2145
+ TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2365
2146
IndexSubset *indices);
2366
2147
2367
2148
TypeRepr *getBaseTypeRepr () const { return BaseTypeRepr; }
2368
- DeclNameWithLoc getOriginalFunctionName () const {
2149
+ DeclNameRefWithLoc getOriginalFunctionName () const {
2369
2150
return OriginalFunctionName;
2370
2151
}
2371
2152
AbstractFunctionDecl *getOriginalFunction () const {
@@ -2399,8 +2180,6 @@ class TransposeAttr final
2399
2180
}
2400
2181
};
2401
2182
2402
- =======
2403
- >>>>>>> upstream_20191216
2404
2183
void simple_display (llvm::raw_ostream &out, const DeclAttribute *attr);
2405
2184
2406
2185
inline SourceLoc extractNearestSourceLoc (const DeclAttribute *attr) {
0 commit comments