Skip to content

Commit e15142c

Browse files
authored
[AD] Support @differentiable where clauses with dependent member types. (#21731)
Previously, where clauses with dependent member type requirements crashed during serialization. This was because where clause requirements were manually computed. The solution is to use requirements computed by the type checker. - Improve comments about where clause type-checking. There are two goals: - Compute the required generic signature for autodiff associated functions, which is formed based on the original function's generic signature and the attribute's where clause requirements. - Store resolved where clause requirements in the attribute for serialization. - Add exercising tests. This completes `@differentiable` attribute where clause type-checking.
1 parent 9f97cfc commit e15142c

File tree

4 files changed

+59
-83
lines changed

4 files changed

+59
-83
lines changed

include/swift/AST/Attr.h

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,8 +1316,7 @@ class ClangImporterSynthesizedTypeAttr : public DeclAttribute {
13161316
/// @differentiable(reverse, wrt: (self, .0, .1), adjoint: bar(_:_:_:seed:))
13171317
class DifferentiableAttr final
13181318
: public DeclAttribute,
1319-
private llvm::TrailingObjects<DifferentiableAttr, AutoDiffParameter,
1320-
Requirement> {
1319+
private llvm::TrailingObjects<DifferentiableAttr, AutoDiffParameter> {
13211320
public:
13221321
struct DeclNameWithLoc {
13231322
DeclName Name;
@@ -1350,21 +1349,25 @@ class DifferentiableAttr final
13501349
FuncDecl *VJPFunction = nullptr;
13511350
/// Checked parameter indices, to be resolved by the type checker.
13521351
AutoDiffParameterIndices *CheckedParameterIndices = nullptr;
1353-
/// The constraint clauses for generic types.
1352+
/// The trailing where clause, if it exists.
13541353
TrailingWhereClause *WhereClause = nullptr;
1355-
/// The number of requirements in the trailing where clause, to be resolved
1356-
/// by the type checker.
1357-
unsigned NumRequirements = 0;
1358-
1359-
explicit DifferentiableAttr(SourceLoc atLoc, SourceRange baseRange,
1354+
/// The requirements for autodiff associated functions. Resolved by the type
1355+
/// checker based on the original function's generic signature and the
1356+
/// attribute's where clause requirements. This is set only if the attribute's
1357+
/// where clause exists.
1358+
MutableArrayRef<Requirement> Requirements;
1359+
1360+
explicit DifferentiableAttr(ASTContext &context, SourceLoc atLoc,
1361+
SourceRange baseRange,
13601362
ArrayRef<AutoDiffParameter> parameters,
13611363
Optional<DeclNameWithLoc> primal,
13621364
Optional<DeclNameWithLoc> adjoint,
13631365
Optional<DeclNameWithLoc> jvp,
13641366
Optional<DeclNameWithLoc> vjp,
13651367
TrailingWhereClause *clause);
13661368

1367-
explicit DifferentiableAttr(SourceLoc atLoc, SourceRange baseRange,
1369+
explicit DifferentiableAttr(ASTContext &context, SourceLoc atLoc,
1370+
SourceRange baseRange,
13681371
ArrayRef<AutoDiffParameter> parameters,
13691372
Optional<DeclNameWithLoc> primal,
13701373
Optional<DeclNameWithLoc> adjoint,
@@ -1417,16 +1420,9 @@ class DifferentiableAttr final
14171420

14181421
TrailingWhereClause *getWhereClause() const { return WhereClause; }
14191422

1420-
ArrayRef<Requirement> getRequirements() const {
1421-
return { getTrailingObjects<Requirement>(), NumRequirements };
1422-
}
1423-
MutableArrayRef<Requirement> getRequirements() {
1424-
return { getTrailingObjects<Requirement>(), NumRequirements };
1425-
}
1426-
void setRequirements(ASTContext &ctx, ArrayRef<Requirement> requirements);
1427-
size_t numTrailingObjects(OverloadToken<Requirement>) const {
1428-
return NumRequirements;
1429-
}
1423+
ArrayRef<Requirement> getRequirements() const { return Requirements; }
1424+
MutableArrayRef<Requirement> getRequirements() { return Requirements; }
1425+
void setRequirements(ASTContext &context, ArrayRef<Requirement> requirements);
14301426

14311427
FuncDecl *getPrimalFunction() const { return PrimalFunction; }
14321428
void setPrimalFunction(FuncDecl *decl) { PrimalFunction = decl; }

lib/AST/Attr.cpp

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,8 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,
10441044
}
10451045

10461046
// SWIFT_ENABLE_TENSORFLOW
1047-
DifferentiableAttr::DifferentiableAttr(SourceLoc atLoc, SourceRange baseRange,
1047+
DifferentiableAttr::DifferentiableAttr(ASTContext &context, SourceLoc atLoc,
1048+
SourceRange baseRange,
10481049
ArrayRef<AutoDiffParameter> parameters,
10491050
Optional<DeclNameWithLoc> primal,
10501051
Optional<DeclNameWithLoc> adjoint,
@@ -1054,13 +1055,13 @@ DifferentiableAttr::DifferentiableAttr(SourceLoc atLoc, SourceRange baseRange,
10541055
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, /*Implicit*/false),
10551056
NumParameters(parameters.size()),
10561057
Primal(std::move(primal)), Adjoint(std::move(adjoint)),
1057-
JVP(std::move(jvp)), VJP(std::move(vjp)), WhereClause(clause),
1058-
NumRequirements(clause ? clause->getRequirements().size() : 0) {
1058+
JVP(std::move(jvp)), VJP(std::move(vjp)), WhereClause(clause) {
10591059
std::copy(parameters.begin(), parameters.end(),
10601060
getTrailingObjects<AutoDiffParameter>());
10611061
}
10621062

1063-
DifferentiableAttr::DifferentiableAttr(SourceLoc atLoc, SourceRange baseRange,
1063+
DifferentiableAttr::DifferentiableAttr(ASTContext &context, SourceLoc atLoc,
1064+
SourceRange baseRange,
10641065
ArrayRef<AutoDiffParameter> parameters,
10651066
Optional<DeclNameWithLoc> primal,
10661067
Optional<DeclNameWithLoc> adjoint,
@@ -1070,12 +1071,10 @@ DifferentiableAttr::DifferentiableAttr(SourceLoc atLoc, SourceRange baseRange,
10701071
: DeclAttribute(DAK_Differentiable, atLoc, baseRange, /*Implicit*/false),
10711072
NumParameters(parameters.size()),
10721073
Primal(std::move(primal)), Adjoint(std::move(adjoint)),
1073-
JVP(std::move(jvp)), VJP(std::move(vjp)),
1074-
NumRequirements(requirements.size()) {
1074+
JVP(std::move(jvp)), VJP(std::move(vjp)) {
10751075
std::copy(parameters.begin(), parameters.end(),
10761076
getTrailingObjects<AutoDiffParameter>());
1077-
std::copy(requirements.begin(), requirements.end(),
1078-
getTrailingObjects<Requirement>());
1077+
setRequirements(context, requirements);
10791078
}
10801079

10811080
DifferentiableAttr *
@@ -1087,10 +1086,9 @@ DifferentiableAttr::create(ASTContext &context, SourceLoc atLoc,
10871086
Optional<DeclNameWithLoc> jvp,
10881087
Optional<DeclNameWithLoc> vjp,
10891088
TrailingWhereClause *clause) {
1090-
unsigned size = totalSizeToAlloc<AutoDiffParameter, Requirement>(
1091-
parameters.size(), clause ? clause->getRequirements().size() : 0);
1089+
unsigned size = totalSizeToAlloc<AutoDiffParameter>(parameters.size());
10921090
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
1093-
return new (mem) DifferentiableAttr(atLoc, baseRange, parameters,
1091+
return new (mem) DifferentiableAttr(context, atLoc, baseRange, parameters,
10941092
std::move(primal), std::move(adjoint),
10951093
std::move(jvp), std::move(vjp), clause);
10961094
}
@@ -1104,22 +1102,19 @@ DifferentiableAttr::create(ASTContext &context, SourceLoc atLoc,
11041102
Optional<DeclNameWithLoc> jvp,
11051103
Optional<DeclNameWithLoc> vjp,
11061104
ArrayRef<Requirement> requirements) {
1107-
unsigned size = totalSizeToAlloc<AutoDiffParameter, Requirement>(
1108-
parameters.size(), requirements.size());
1105+
unsigned size = totalSizeToAlloc<AutoDiffParameter>(parameters.size());
11091106
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
1110-
return new (mem) DifferentiableAttr(atLoc, baseRange, parameters,
1107+
return new (mem) DifferentiableAttr(context, atLoc, baseRange, parameters,
11111108
std::move(primal), std::move(adjoint),
11121109
std::move(jvp), std::move(vjp),
11131110
requirements);
11141111
}
11151112

1116-
void DifferentiableAttr::setRequirements(ASTContext &ctx,
1113+
void DifferentiableAttr::setRequirements(ASTContext &context,
11171114
ArrayRef<Requirement> requirements) {
1118-
assert(requirements.size() <= NumRequirements &&
1119-
"Requirements size must not exceed number of allocated requirements");
1120-
NumRequirements = requirements.size();
1121-
std::copy(requirements.begin(), requirements.end(),
1122-
getTrailingObjects<Requirement>());
1115+
Requirements =
1116+
context.AllocateUninitialized<Requirement>(requirements.size());
1117+
std::copy(requirements.begin(), requirements.end(), Requirements.data());
11231118
}
11241119

11251120
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2289,21 +2289,27 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
22892289
return;
22902290
}
22912291

2292-
// Type-check 'where' clause.
2292+
// Handle 'where' clause, if it exists.
2293+
// - Resolve attribute where clause requirements and store in the attribute
2294+
// for serialization.
2295+
// - Compute generic signature for autodiff associated functions based on
2296+
// the original function's generate signature and the attribute's where
2297+
// clause requirements.
22932298
GenericSignature *whereClauseGenSig = nullptr;
22942299
GenericEnvironment *whereClauseGenEnv = nullptr;
22952300
if (auto whereClause = attr->getWhereClause()) {
22962301
if (whereClause->getRequirements().empty()) {
2297-
// Report an empty where clause.
2302+
// Where clause must not be empty.
22982303
TC.diagnose(attr->getLocation(),
22992304
diag::differentiable_attr_empty_where_clause);
23002305
attr->setInvalid();
23012306
return;
23022307
}
23032308

2304-
auto *genericSig = original->getGenericSignature();
2305-
if (!genericSig) {
2306-
// Only generic functions can have trailing where clauses.
2309+
auto *originalGenSig = original->getGenericSignature();
2310+
if (!originalGenSig) {
2311+
// Attributes with where clauses can only be declared on
2312+
// generic functions.
23072313
TC.diagnose(attr->getLocation(),
23082314
diag::differentiable_attr_nongeneric_trailing_where,
23092315
original->getFullName())
@@ -2312,34 +2318,16 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23122318
return;
23132319
}
23142320

2315-
// Form a new generic signature.
2321+
// Build a new generic signature for autodiff associated functions.
23162322
GenericSignatureBuilder builder(ctx);
2317-
// First, add the old generic signature.
2318-
builder.addGenericSignature(genericSig);
2319-
// Go over the set of requirements, adding them to the builder.
2320-
SmallVector<Requirement, 4> convertedRequirements;
2321-
2322-
// Set where clause owner.
2323-
// Default to using @differentiable attribute as owner.
2324-
// If original function is declared on a protocol, use protocol's generic
2325-
// parameters instead.
2326-
WhereClauseOwner owner(original, attr);
2327-
if (auto nominal = original->getDeclContext()->getSelfNominalTypeDecl()) {
2328-
if (auto proto = dyn_cast<ProtocolDecl>(nominal)) {
2329-
auto DC = original->getDeclContext();
2330-
auto genericParams = proto->createGenericParams(DC);
2331-
TC.prepareGenericParamList(genericParams, DC);
2332-
genericParams->addTrailingWhereClause(ctx, whereClause->getWhereLoc(),
2333-
whereClause->getRequirements());
2334-
owner = WhereClauseOwner(original, genericParams);
2335-
}
2336-
}
2323+
// Add the original function's generic signature.
2324+
builder.addGenericSignature(originalGenSig);
23372325

23382326
using FloatingRequirementSource =
2339-
GenericSignatureBuilder::FloatingRequirementSource;
2327+
GenericSignatureBuilder::FloatingRequirementSource;
23402328

23412329
RequirementRequest::visitRequirements(
2342-
owner, TypeResolutionStage::Structural,
2330+
WhereClauseOwner(original, attr), TypeResolutionStage::Structural,
23432331
[&](const Requirement &req, RequirementRepr *reqRepr) {
23442332
// Check additional constraints.
23452333
// TODO: refine constraints.
@@ -2355,7 +2343,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23552343
.highlight(reqRepr->getSourceRange());
23562344
return false;
23572345

2358-
// Conformance requirements are supported if:
2346+
// Conformance requirements are valid if:
23592347
// - The first type is a generic type parameter type.
23602348
// - The second type is a protocol type.
23612349
case RequirementKind::Conformance:
@@ -2373,20 +2361,20 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23732361
break;
23742362
}
23752363

2376-
// Add the requirement to the generic signature builder.
2364+
// Add requirement to generic signature builder.
23772365
builder.addRequirement(req, reqRepr,
23782366
FloatingRequirementSource::forExplicit(reqRepr),
23792367
nullptr, original->getModuleContext());
2380-
convertedRequirements.push_back(getCanonicalRequirement(req));
23812368
return false;
23822369
});
23832370

2384-
// Store the converted requirements in the attribute.
2385-
attr->setRequirements(ctx, convertedRequirements);
2371+
// Compute generic signature and environment for autodiff associated
2372+
// functions.
23862373
whereClauseGenSig = std::move(builder).computeGenericSignature(
23872374
attr->getLocation(), /*allowConcreteGenericParams=*/true);
23882375
whereClauseGenEnv = whereClauseGenSig->createGenericEnvironment();
2389-
whereClauseGenEnv->setOwningDeclContext(original->getDeclContext());
2376+
// Store the resolved requirements in the attribute.
2377+
attr->setRequirements(ctx, whereClauseGenSig->getRequirements());
23902378
}
23912379

23922380
// Resolve the primal declaration, if it exists.

test/Serialization/differentiable_attr.swift

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func vjpSimpleVJP(x: Float) -> (Float, (Float) -> Float) {
9898
return (x, { v in v })
9999
}
100100

101-
// CHECK-DAG: @differentiable(vjp: vjpTestWhereClause where T : Differentiable)
101+
// CHECK-DAG: @differentiable(vjp: vjpTestWhereClause where T : Differentiable, T : Numeric)
102102
// CHECK-DAG: func testWhereClause<T>(x: T) -> T where T : Numeric
103103
@differentiable(vjp: vjpTestWhereClause where T : Differentiable)
104104
func testWhereClause<T : Numeric>(x: T) -> T {
@@ -112,7 +112,7 @@ func vjpTestWhereClause<T>(x: T) -> (T, (T.CotangentVector) -> T.CotangentVector
112112

113113
protocol P {}
114114
extension P {
115-
// CHECK-DAG: @differentiable(wrt: (self), vjp: vjpTestWhereClause where Self : Differentiable)
115+
// CHECK-DAG: @differentiable(wrt: (self), vjp: vjpTestWhereClause where Self : Differentiable, Self : P)
116116
// CHECK-DAG: func testWhereClause() -> Self
117117
@differentiable(wrt: (self), vjp: vjpTestWhereClause where Self : Differentiable)
118118
func testWhereClause() -> Self {
@@ -125,12 +125,11 @@ extension P where Self : Differentiable {
125125
}
126126
}
127127

128-
/*
129128
// NOTE: The failing tests involve where clauses with member type constraints.
130129
// They pass type-checking but crash during serialization.
131130

132-
// HECK-DAG: @differentiable(vjp: vjpTestWhereClauseMemberTypeConstraint where T : Differentiable, T == T.CotangentVector)
133-
// HECK-DAG: func testWhereClauseMemberTypeConstraint<T : Numeric>(x: T) -> T {
131+
// CHECK-DAG: @differentiable(vjp: vjpTestWhereClauseMemberTypeConstraint where T : Differentiable, T : Numeric, T == T.CotangentVector)
132+
// CHECK-DAG: func testWhereClauseMemberTypeConstraint<T>(x: T) -> T where T : Numeric
134133
@differentiable(vjp: vjpTestWhereClauseMemberTypeConstraint where T : Differentiable, T == T.CotangentVector)
135134
func testWhereClauseMemberTypeConstraint<T : Numeric>(x: T) -> T {
136135
return x
@@ -141,18 +140,16 @@ func vjpTestWhereClauseMemberTypeConstraint<T>(x: T) -> (T, (T) -> T)
141140
return (x, { v in v })
142141
}
143142

144-
protocol P {}
145143
extension P {
146-
// HECK-DAG: @differentiable(wrt: (self), vjp: vjpTestWhereClauseMemberTypeConstraint where Self.CotangentVector == Self, Self : Differentiable)
147-
// HECK-DAG: func testWhereClauseMemberTypeConstraint() -> Self {
144+
// CHECK-DAG: @differentiable(wrt: (self), vjp: vjpTestWhereClauseMemberTypeConstraint where Self : Differentiable, Self : P, Self == Self.CotangentVector)
145+
// CHECK-DAG: func testWhereClauseMemberTypeConstraint() -> Self
148146
@differentiable(wrt: (self), vjp: vjpTestWhereClauseMemberTypeConstraint where Self.CotangentVector == Self, Self : Differentiable)
149147
func testWhereClauseMemberTypeConstraint() -> Self {
150148
return self
151149
}
152150
}
153-
extension P where Self : Differentiable {
151+
extension P where Self : Differentiable, Self == Self.CotangentVector {
154152
func vjpTestWhereClauseMemberTypeConstraint() -> (Self, (Self.CotangentVector) -> Self.CotangentVector) {
155153
return (self, { v in v })
156154
}
157155
}
158-
*/

0 commit comments

Comments
 (0)