Skip to content

Commit dfa65b8

Browse files
committed
[DNM] [AutoDiff upstream] @differentiable attribute type-checking.
`@differentiable` attribute type-checking mega-patch. Upstreams all necessary code from `tensorflow` branch. This mega-patch is just for end-to-end testing. Code will be upstreamed incrementally.
1 parent 4f7349b commit dfa65b8

23 files changed

+2683
-13
lines changed

include/swift/AST/ASTContext.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ namespace swift {
6464
class DeclContext;
6565
class DefaultArgumentInitializer;
6666
class DerivativeAttr;
67+
class DifferentiableAttr;
6768
class ExtensionDecl;
6869
class ForeignRepresentationInfo;
6970
class FuncDecl;
@@ -284,6 +285,14 @@ class ASTContext final {
284285
/// across invocations of both the parser and the type-checker.
285286
unsigned NextAutoClosureDiscriminator = 0;
286287

288+
/// Cached mapping from types to their associated tangent spaces.
289+
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;
290+
291+
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
292+
/// diagnose duplicate `@differentiable` attributes for the same key.
293+
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
294+
DifferentiableAttrs;
295+
287296
/// Cache of `@derivative` attributes keyed by parameter indices and
288297
/// derivative function kind. Used to diagnose duplicate `@derivative`
289298
/// attributes for the same key.

include/swift/AST/ASTScope.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,41 @@ class SpecializeAttributeScope final : public ASTScopeImpl {
15681568
DeclConsumer) const override;
15691569
};
15701570

1571+
/// A `@differentiable` attribute scope.
1572+
///
1573+
/// This exists because `@differentiable` attribute may have a `where` clause
1574+
/// referring to generic parameters from some generic context.
1575+
class DifferentiableAttributeScope final : public ASTScopeImpl {
1576+
public:
1577+
DifferentiableAttr *const differentiableAttr;
1578+
ValueDecl *const attributedDeclaration;
1579+
1580+
DifferentiableAttributeScope(DifferentiableAttr *diffAttr, ValueDecl *decl)
1581+
: differentiableAttr(diffAttr), attributedDeclaration(decl) {}
1582+
virtual ~DifferentiableAttributeScope() {}
1583+
1584+
std::string getClassName() const override;
1585+
SourceRange
1586+
getSourceRangeOfThisASTNode(bool omitAssertions = false) const override;
1587+
NullablePtr<const void> addressForPrinting() const override {
1588+
return differentiableAttr;
1589+
}
1590+
1591+
NullablePtr<AbstractStorageDecl>
1592+
getEnclosingAbstractStorageDecl() const override;
1593+
1594+
NullablePtr<DeclAttribute> getDeclAttributeIfAny() const override {
1595+
return differentiableAttr;
1596+
}
1597+
NullablePtr<const void> getReferrent() const override;
1598+
1599+
protected:
1600+
ASTScopeImpl *expandSpecifically(ScopeCreator &) override;
1601+
bool lookupLocalsOrMembers(ArrayRef<const ASTScopeImpl *>,
1602+
DeclConsumer) const override;
1603+
bool doesContextMatchStartingContext(const DeclContext *) const override;
1604+
};
1605+
15711606
class SubscriptDeclScope final : public ASTScopeImpl {
15721607
public:
15731608
SubscriptDecl *const decl;

include/swift/AST/Attr.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,7 @@ class DifferentiableAttr final
16471647
private llvm::TrailingObjects<DifferentiableAttr,
16481648
ParsedAutoDiffParameter> {
16491649
friend TrailingObjects;
1650+
friend class DifferentiableAttributeTypeCheckRequest;
16501651

16511652
/// The declaration on which the `@differentiable` attribute is declared.
16521653
/// May not be a valid declaration for `@differentiable` attributes.
@@ -1667,7 +1668,8 @@ class DifferentiableAttr final
16671668
/// specified.
16681669
FuncDecl *VJPFunction = nullptr;
16691670
/// The differentiability parameter indices, resolved by the type checker.
1670-
IndexSubset *ParameterIndices = nullptr;
1671+
/// The bit stores whether the parameter indices have been computed.
1672+
llvm::PointerIntPair<IndexSubset *, 1, bool> ParameterIndicesAndBit;
16711673
/// The trailing where clause (optional).
16721674
TrailingWhereClause *WhereClause = nullptr;
16731675
/// The generic signature for autodiff associated functions. Resolved by the
@@ -1723,12 +1725,14 @@ class DifferentiableAttr final
17231725
/// registered VJP.
17241726
Optional<DeclNameRefWithLoc> getVJP() const { return VJP; }
17251727

1726-
IndexSubset *getParameterIndices() const {
1727-
return ParameterIndices;
1728-
}
1729-
void setParameterIndices(IndexSubset *parameterIndices) {
1730-
ParameterIndices = parameterIndices;
1731-
}
1728+
private:
1729+
/// Returns true if the given `@differentiable` attribute has been
1730+
/// type-checked.
1731+
bool hasBeenTypeChecked() const;
1732+
1733+
public:
1734+
IndexSubset *getParameterIndices() const;
1735+
void setParameterIndices(IndexSubset *parameterIndices);
17321736

17331737
/// The parsed differentiability parameters, i.e. the list of parameters
17341738
/// specified in 'wrt:'.

include/swift/AST/AutoDiff.h

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <cstdint>
2121

22+
#include "swift/AST/GenericSignature.h"
2223
#include "swift/AST/Identifier.h"
2324
#include "swift/AST/IndexSubset.h"
2425
#include "swift/AST/Type.h"
@@ -28,6 +29,7 @@
2829
namespace swift {
2930

3031
class AnyFunctionType;
32+
class TupleType;
3133

3234
/// A function type differentiability kind.
3335
enum class DifferentiabilityKind : uint8_t {
@@ -70,6 +72,26 @@ struct AutoDiffDerivativeFunctionKind {
7072
}
7173
};
7274

75+
/// Identifies an autodiff derivative function configuration:
76+
/// - Parameter indices.
77+
/// - Result indices.
78+
/// - Derivative generic signature (optional).
79+
struct AutoDiffConfig {
80+
IndexSubset *parameterIndices;
81+
IndexSubset *resultIndices;
82+
GenericSignature derivativeGenericSignature;
83+
84+
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
85+
IndexSubset *resultIndices,
86+
GenericSignature derivativeGenericSignature)
87+
: parameterIndices(parameterIndices), resultIndices(resultIndices),
88+
derivativeGenericSignature(derivativeGenericSignature) {}
89+
90+
void print(llvm::raw_ostream &s = llvm::outs()) const;
91+
LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
92+
"only for use within the debugger");
93+
};
94+
7395
class ParsedAutoDiffParameter {
7496
public:
7597
enum class Kind { Named, Ordered, Self };
@@ -133,6 +155,91 @@ class ParsedAutoDiffParameter {
133155
}
134156
};
135157

158+
/// The tangent space of a type.
159+
///
160+
/// For `Differentiable`-conforming types:
161+
/// - The tangent space is the `TangentVector` associated type.
162+
///
163+
/// For tuple types:
164+
/// - The tangent space is a tuple of the elements' tangent space types, for the
165+
/// elements that have a tangent space.
166+
///
167+
/// For function types whose innermost result type has a tangent space:
168+
/// - The tangent space is the same function type, replacing the innermost
169+
/// result type with its tangent space type.
170+
///
171+
/// Other types have no tangent space.
172+
class TangentSpace {
173+
public:
174+
/// A tangent space kind.
175+
enum class Kind {
176+
/// The `TangentVector` associated type of a `Differentiable`-conforming
177+
/// type.
178+
TangentVector,
179+
/// A product of tangent spaces as a tuple.
180+
Tuple,
181+
/// A function type whose innermost result is the `TangentVector`
182+
/// associated type of a `Differentiable`-conforming type.
183+
Function
184+
};
185+
186+
private:
187+
Kind kind;
188+
union Value {
189+
// TangentVector
190+
Type tangentVectorType;
191+
// Tuple
192+
TupleType *tupleType;
193+
// Function
194+
AnyFunctionType *functionType;
195+
196+
Value(Type tangentVectorType) : tangentVectorType(tangentVectorType) {}
197+
Value(TupleType *tupleType) : tupleType(tupleType) {}
198+
Value(AnyFunctionType *functionType) : functionType(functionType) {}
199+
} value;
200+
201+
TangentSpace(Kind kind, Value value) : kind(kind), value(value) {}
202+
203+
public:
204+
TangentSpace() = delete;
205+
206+
static TangentSpace getTangentVector(Type tangentVectorType) {
207+
return {Kind::TangentVector, tangentVectorType};
208+
}
209+
static TangentSpace getTuple(TupleType *tupleTy) {
210+
return {Kind::Tuple, tupleTy};
211+
}
212+
static TangentSpace getFunction(AnyFunctionType *fnTy) {
213+
return {Kind::Function, fnTy};
214+
}
215+
216+
bool isTangentVector() const { return kind == Kind::TangentVector; }
217+
bool isTuple() const { return kind == Kind::Tuple; }
218+
219+
Kind getKind() const { return kind; }
220+
Type getTangentVector() const {
221+
assert(kind == Kind::TangentVector);
222+
return value.tangentVectorType;
223+
}
224+
TupleType *getTuple() const {
225+
assert(kind == Kind::Tuple);
226+
return value.tupleType;
227+
}
228+
AnyFunctionType *getFunction() const {
229+
assert(kind == Kind::Function);
230+
return value.functionType;
231+
}
232+
233+
/// Get the tangent space type.
234+
Type getType() const;
235+
236+
/// Get the tangent space canonical type.
237+
CanType getCanonicalType() const;
238+
239+
/// Get the underlying
240+
NominalTypeDecl *getNominal() const;
241+
};
242+
136243
/// Automatic differentiation utility namespace.
137244
namespace autodiff {
138245

@@ -148,10 +255,59 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
148255

149256
namespace llvm {
150257

258+
using swift::AutoDiffConfig;
151259
using swift::AutoDiffDerivativeFunctionKind;
260+
using swift::GenericSignature;
261+
using swift::IndexSubset;
152262

153263
template <typename T> struct DenseMapInfo;
154264

265+
template <> struct DenseMapInfo<AutoDiffConfig> {
266+
static AutoDiffConfig getEmptyKey() {
267+
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
268+
// The `derivativeGenericSignature` component must be `nullptr` so that
269+
// `getHashValue` and `isEqual` do not try to `getCanonicalSignature()` on
270+
// an invalid pointer.
271+
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
272+
nullptr};
273+
}
274+
275+
static AutoDiffConfig getTombstoneKey() {
276+
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
277+
// The `derivativeGenericSignature` component must be `nullptr` so that
278+
// `getHashValue` and `isEqual` do not try to `getCanonicalSignature()` on
279+
// an invalid pointer.
280+
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
281+
nullptr};
282+
}
283+
284+
static unsigned getHashValue(const AutoDiffConfig &Val) {
285+
auto canGenSig =
286+
Val.derivativeGenericSignature
287+
? Val.derivativeGenericSignature->getCanonicalSignature()
288+
: nullptr;
289+
unsigned combinedHash = hash_combine(
290+
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
291+
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
292+
DenseMapInfo<GenericSignature>::getHashValue(canGenSig));
293+
return combinedHash;
294+
}
295+
296+
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
297+
auto lhsCanGenSig =
298+
LHS.derivativeGenericSignature
299+
? LHS.derivativeGenericSignature->getCanonicalSignature()
300+
: nullptr;
301+
auto rhsCanGenSig =
302+
RHS.derivativeGenericSignature
303+
? RHS.derivativeGenericSignature->getCanonicalSignature()
304+
: nullptr;
305+
return LHS.parameterIndices == RHS.parameterIndices &&
306+
LHS.resultIndices == RHS.resultIndices &&
307+
DenseMapInfo<GenericSignature>::isEqual(lhsCanGenSig, rhsCanGenSig);
308+
}
309+
};
310+
155311
template <> struct DenseMapInfo<AutoDiffDerivativeFunctionKind> {
156312
static AutoDiffDerivativeFunctionKind getEmptyKey() {
157313
return static_cast<AutoDiffDerivativeFunctionKind::innerty>(

include/swift/AST/Decl.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5734,6 +5734,20 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
57345734
private:
57355735
ParameterList *Params;
57365736

5737+
/// The generation at which we last loaded derivative function configurations.
5738+
unsigned DerivativeFunctionConfigGeneration = 0;
5739+
/// Prepare to traverse the list of derivative function configurations.
5740+
void prepareDerivativeFunctionConfigurations();
5741+
5742+
/// A uniqued list of derivative function configurations.
5743+
/// - `@differentiable` and `@derivative` attribute type-checking is
5744+
/// responsible for populating derivative function configurations specified
5745+
/// in the current module.
5746+
/// - Module loading is responsible for populating derivative function
5747+
/// configurations from imported modules.
5748+
struct DerivativeFunctionConfigurationList;
5749+
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
5750+
57375751
protected:
57385752
// If a function has a body at all, we have either a parsed body AST node or
57395753
// we have saved the end location of the unparsed body.
@@ -6058,6 +6072,12 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
60586072
/// constructor.
60596073
bool hasDynamicSelfResult() const;
60606074

6075+
/// Get all derivative function configurations.
6076+
ArrayRef<AutoDiffConfig> getDerivativeFunctionConfigurations();
6077+
6078+
/// Add the given derivative function configuration.
6079+
void addDerivativeFunctionConfiguration(AutoDiffConfig config);
6080+
60616081
using DeclContext::operator new;
60626082
using Decl::getASTContext;
60636083
};

include/swift/AST/DiagnosticsSema.def

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,6 +2904,62 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
29042904
"containing type %0 does not conform to protocol %1",
29052905
(DeclName, DeclName))
29062906

2907+
// @differentiable
2908+
ERROR(differentiable_attr_void_result,none,
2909+
"cannot differentiate void function %0", (DeclName))
2910+
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
2911+
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
2912+
"attribute for transpose registration instead", ())
2913+
ERROR(differentiable_attr_overload_not_found,none,
2914+
"%0 does not have expected type %1", (DeclNameRef, Type))
2915+
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
2916+
// mention "same generic requirements".
2917+
ERROR(differentiable_attr_duplicate,none,
2918+
"duplicate '@differentiable' attribute with same parameters", ())
2919+
NOTE(differentiable_attr_duplicate_note,none,
2920+
"other attribute declared here", ())
2921+
ERROR(differentiable_attr_function_not_same_type_context,none,
2922+
"%0 is not defined in the current type context", (DeclNameRef))
2923+
ERROR(differentiable_attr_derivative_not_function,none,
2924+
"registered derivative %0 must be a 'func' declaration", (DeclNameRef))
2925+
ERROR(differentiable_attr_class_derivative_not_final,none,
2926+
"class member derivative must be final", ())
2927+
ERROR(differentiable_attr_invalid_access,none,
2928+
"derivative function %0 is required to either be public or "
2929+
"'@usableFromInline' because the original function %1 is public or "
2930+
"'@usableFromInline'", (DeclNameRef, DeclName))
2931+
ERROR(differentiable_attr_result_not_differentiable,none,
2932+
"can only differentiate functions with results that conform to "
2933+
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
2934+
ERROR(differentiable_attr_protocol_req_where_clause,none,
2935+
"'@differentiable' attribute on protocol requirement cannot specify "
2936+
"'where' clause", ())
2937+
ERROR(differentiable_attr_protocol_req_assoc_func,none,
2938+
"'@differentiable' attribute on protocol requirement cannot specify "
2939+
"'jvp:' or 'vjp:'", ())
2940+
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
2941+
"'@differentiable' attribute on stored property cannot specify "
2942+
"'jvp:' or 'vjp:'", ())
2943+
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
2944+
"'@differentiable' attribute cannot be declared on class methods "
2945+
"returning 'Self'", ())
2946+
// TODO(TF-654): Remove when differentiation supports class initializers.
2947+
ERROR(differentiable_attr_class_init_not_yet_supported,none,
2948+
"'@differentiable' attribute does not yet support class initializers",
2949+
())
2950+
ERROR(differentiable_attr_empty_where_clause,none,
2951+
"empty 'where' clause in '@differentiable' attribute", ())
2952+
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
2953+
"'where' clause is valid only when original function is generic %0",
2954+
(DeclName))
2955+
ERROR(differentiable_attr_layout_req_unsupported,none,
2956+
"'@differentiable' attribute does not yet support layout requirements",
2957+
())
2958+
ERROR(overriding_decl_missing_differentiable_attr,none,
2959+
"overriding declaration is missing attribute '%0'", (StringRef))
2960+
NOTE(protocol_witness_missing_differentiable_attr,none,
2961+
"candidate is missing attribute '%0'", (StringRef))
2962+
29072963
// @derivative
29082964
ERROR(derivative_attr_expected_result_tuple,none,
29092965
"'@derivative(of:)' attribute requires function to return a two-element "

0 commit comments

Comments
 (0)