Skip to content

Commit addbf91

Browse files
committed
Resolve remaining merge conflicts.
1 parent f4fbe0b commit addbf91

33 files changed

+1845
-2968
lines changed

include/swift/AST/Attr.h

Lines changed: 108 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,15 @@ class DifferentiableAttr final
17091709
private llvm::TrailingObjects<DifferentiableAttr,
17101710
ParsedAutoDiffParameter> {
17111711
friend TrailingObjects;
1712-
1712+
// SWIFT_ENABLE_TENSORFLOW
1713+
friend class DifferentiableAttributeParameterIndicesRequest;
1714+
// SWIFT_ENABLE_TENSORFLOW END
1715+
1716+
// SWIFT_ENABLE_TENSORFLOW
1717+
/// The declaration on which the `@differentiable` attribute is declared.
1718+
/// Resolved during parsing and deserialization.
1719+
Decl *OriginalDeclaration = nullptr;
1720+
// SWIFT_ENABLE_TENSORFLOW END
17131721
/// Whether this function is linear (optional).
17141722
bool Linear;
17151723
/// The number of parsed parameters specified in 'wrt:'.
@@ -1724,8 +1732,13 @@ class DifferentiableAttr final
17241732
/// The VJP function (optional), resolved by the type checker if VJP name is
17251733
/// specified.
17261734
FuncDecl *VJPFunction = nullptr;
1735+
// SWIFT_ENABLE_TENSORFLOW
1736+
// NOTE: Parameter indices requestification is done on `tensorflow` branch but
1737+
// has not yet been upstreamed to `master` branch.
17271738
/// The differentiation parameters' indices, resolved by the type checker.
1728-
IndexSubset *ParameterIndices = nullptr;
1739+
/// The bit stores whether the parameter indices have been computed.
1740+
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
1741+
// SWIFT_ENABLE_TENSORFLOW END
17291742
/// The trailing where clause (optional).
17301743
TrailingWhereClause *WhereClause = nullptr;
17311744
/// The generic signature for autodiff associated functions. Resolved by the
@@ -1765,6 +1778,9 @@ class DifferentiableAttr final
17651778
Optional<DeclNameRefWithLoc> vjp,
17661779
GenericSignature derivativeGenSig);
17671780

1781+
Decl *getOriginalDeclaration() const { return OriginalDeclaration; }
1782+
void setOriginalDeclaration(Decl *decl);
1783+
17681784
/// Get the optional 'jvp:' function name and location.
17691785
/// Use this instead of `getJVPFunction` to check whether the attribute has a
17701786
/// registered JVP.
@@ -1775,12 +1791,13 @@ class DifferentiableAttr final
17751791
/// registered VJP.
17761792
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }
17771793

1778-
IndexSubset *getParameterIndices() const {
1779-
return ParameterIndices;
1780-
}
1781-
void setParameterIndices(IndexSubset *parameterIndices) {
1782-
ParameterIndices = parameterIndices;
1783-
}
1794+
// SWIFT_ENABLE_TENSORFLOW
1795+
// NOTE: Parameter indices requestification is done on `tensorflow` branch but
1796+
// has not yet been upstreamed to `master` branch.
1797+
bool hasComputedParameterIndices() const;
1798+
IndexSubset *getParameterIndices() const;
1799+
void setParameterIndices(IndexSubset *paramIndices);
1800+
// SWIFT_ENABLE_TENSORFLOW END
17841801

17851802
/// The parsed differentiation parameters, i.e. the list of parameters
17861803
/// specified in 'wrt:'.
@@ -1920,6 +1937,89 @@ class DerivativeAttr final
19201937
}
19211938
};
19221939

1940+
// SWIFT_ENABLE_TENSORFLOW
1941+
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
1942+
using DifferentiatingAttr = DerivativeAttr;
1943+
1944+
/// Attribute that registers a function as a transpose of another function.
1945+
///
1946+
/// Examples:
1947+
/// @transpose(of: foo)
1948+
/// @transpose(of: +, wrt: (lhs, rhs))
1949+
class TransposeAttr final
1950+
: public DeclAttribute,
1951+
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
1952+
friend TrailingObjects;
1953+
1954+
/// The base type for the referenced original declaration. This field is
1955+
/// non-null only for parsed attributes that reference a qualified original
1956+
/// declaration. This field is not serialized; type-checking uses it to
1957+
/// resolve the original declaration, which is serialized.
1958+
TypeRepr *BaseTypeRepr;
1959+
/// The original function name.
1960+
DeclNameRefWithLoc OriginalFunctionName;
1961+
/// The original function declaration, resolved by the type checker.
1962+
AbstractFunctionDecl *OriginalFunction = nullptr;
1963+
/// The number of parsed parameters specified in 'wrt:'.
1964+
unsigned NumParsedParameters = 0;
1965+
/// The transposed parameters' indices, resolved by the type checker.
1966+
IndexSubset *ParameterIndices = nullptr;
1967+
1968+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1969+
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1970+
ArrayRef<ParsedAutoDiffParameter> params);
1971+
1972+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1973+
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1974+
IndexSubset *indices);
1975+
1976+
public:
1977+
static TransposeAttr *create(ASTContext &context, bool implicit,
1978+
SourceLoc atLoc, SourceRange baseRange,
1979+
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1980+
ArrayRef<ParsedAutoDiffParameter> params);
1981+
1982+
static TransposeAttr *create(ASTContext &context, bool implicit,
1983+
SourceLoc atLoc, SourceRange baseRange,
1984+
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
1985+
IndexSubset *indices);
1986+
1987+
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
1988+
DeclNameRefWithLoc getOriginalFunctionName() const {
1989+
return OriginalFunctionName;
1990+
}
1991+
AbstractFunctionDecl *getOriginalFunction() const {
1992+
return OriginalFunction;
1993+
}
1994+
void setOriginalFunction(AbstractFunctionDecl *decl) {
1995+
OriginalFunction = decl;
1996+
}
1997+
1998+
/// The parsed transposed parameters, i.e. the list of parameters specified in
1999+
/// 'wrt:'.
2000+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
2001+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2002+
}
2003+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
2004+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2005+
}
2006+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
2007+
return NumParsedParameters;
2008+
}
2009+
2010+
IndexSubset *getParameterIndices() const {
2011+
return ParameterIndices;
2012+
}
2013+
void setParameterIndices(IndexSubset *parameterIndices) {
2014+
ParameterIndices = parameterIndices;
2015+
}
2016+
2017+
static bool classof(const DeclAttribute *DA) {
2018+
return DA->getKind() == DAK_Transpose;
2019+
}
2020+
};
2021+
// SWIFT_ENABLE_TENSORFLOW END
2022+
19232023
/// Attributes that may be applied to declarations.
19242024
class DeclAttributes {
19252025
/// Linked list of declaration attributes.
@@ -2099,87 +2199,6 @@ class DeclAttributes {
20992199
SourceLoc getStartLoc(bool forModifiers = false) const;
21002200
};
21012201

2102-
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
2103-
using DifferentiatingAttr = DerivativeAttr;
2104-
2105-
/// Attribute that registers a function as a transpose of another function.
2106-
///
2107-
/// Examples:
2108-
/// @transpose(of: foo)
2109-
/// @transpose(of: +, wrt: (lhs, rhs))
2110-
class TransposeAttr final
2111-
: public DeclAttribute,
2112-
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
2113-
friend TrailingObjects;
2114-
2115-
/// The base type for the referenced original declaration. This field is
2116-
/// non-null only for parsed attributes that reference a qualified original
2117-
/// declaration. This field is not serialized; type-checking uses it to
2118-
/// resolve the original declaration, which is serialized.
2119-
TypeRepr *BaseTypeRepr;
2120-
/// The original function name.
2121-
DeclNameRefWithLoc OriginalFunctionName;
2122-
/// The original function declaration, resolved by the type checker.
2123-
AbstractFunctionDecl *OriginalFunction = nullptr;
2124-
/// The number of parsed parameters specified in 'wrt:'.
2125-
unsigned NumParsedParameters = 0;
2126-
/// The transposed parameters' indices, resolved by the type checker.
2127-
IndexSubset *ParameterIndices = nullptr;
2128-
2129-
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
2130-
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2131-
ArrayRef<ParsedAutoDiffParameter> params);
2132-
2133-
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
2134-
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2135-
IndexSubset *indices);
2136-
2137-
public:
2138-
static TransposeAttr *create(ASTContext &context, bool implicit,
2139-
SourceLoc atLoc, SourceRange baseRange,
2140-
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2141-
ArrayRef<ParsedAutoDiffParameter> params);
2142-
2143-
static TransposeAttr *create(ASTContext &context, bool implicit,
2144-
SourceLoc atLoc, SourceRange baseRange,
2145-
TypeRepr *baseTypeRepr, DeclNameRefWithLoc original,
2146-
IndexSubset *indices);
2147-
2148-
TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; }
2149-
DeclNameRefWithLoc getOriginalFunctionName() const {
2150-
return OriginalFunctionName;
2151-
}
2152-
AbstractFunctionDecl *getOriginalFunction() const {
2153-
return OriginalFunction;
2154-
}
2155-
void setOriginalFunction(AbstractFunctionDecl *decl) {
2156-
OriginalFunction = decl;
2157-
}
2158-
2159-
/// The parsed transposed parameters, i.e. the list of parameters specified in
2160-
/// 'wrt:'.
2161-
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
2162-
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2163-
}
2164-
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
2165-
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
2166-
}
2167-
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
2168-
return NumParsedParameters;
2169-
}
2170-
2171-
IndexSubset *getParameterIndices() const {
2172-
return ParameterIndices;
2173-
}
2174-
void setParameterIndices(IndexSubset *parameterIndices) {
2175-
ParameterIndices = parameterIndices;
2176-
}
2177-
2178-
static bool classof(const DeclAttribute *DA) {
2179-
return DA->getKind() == DAK_Transpose;
2180-
}
2181-
};
2182-
21832202
void simple_display(llvm::raw_ostream &out, const DeclAttribute *attr);
21842203

21852204
inline SourceLoc extractNearestSourceLoc(const DeclAttribute *attr) {

0 commit comments

Comments
 (0)