Skip to content

Commit fc4c701

Browse files
rxweidan-zheng
authored andcommitted
[AutoDiff] [stdlib] Derive conformances for 'EuclideanDifferentiable'. (#26867)
Add derived conformances for the `EuclideanDifferentiable` protocol introduced in #26287 when a conforming type satisfies `Differentiable` conformance synthesis requirements. ```swift struct Foo<T>: EuclideanDifferentiable where T.TangentVector == T { var x: T var y: T @noDerivative var z: Bool // The compiler synthesizes the following `EuclideanDifferentiable` requirement: // var vectorView: TangentVector { // return TangentVector(x: x, y: y) // } } ``` Remove `vectorView`'s setter. This should not have been added to the protocol, for the same reason as [TF-208](https://bugs.swift.org/browse/TF-208). A projection (`vectorView`) of a subset of properties should reflect these properties joint mutability, which is impossible to express at the moment. Resolves [TF-777](https://bugs.swift.org/browse/TF-777).
1 parent 6bffe97 commit fc4c701

13 files changed

+281
-64
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,6 +2576,8 @@ ERROR(broken_vector_protocol_requirement,none,
25762576
"VectorProtocol protocol is broken: unexpected requirement", ())
25772577
ERROR(broken_differentiable_requirement,none,
25782578
"Differentiable protocol is broken: unexpected requirement", ())
2579+
ERROR(broken_euclidean_differentiable_requirement,none,
2580+
"EuclideanDifferentiable protocol is broken: unexpected requirement", ())
25792581
ERROR(broken_key_path_iterable_requirement,none,
25802582
"KeyPathIterable protocol is broken: unexpected requirement", ())
25812583
ERROR(broken_tensor_array_protocol_requirement,none,

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ IDENTIFIER(x)
157157
// Differentiable
158158
IDENTIFIER(TangentVector)
159159
IDENTIFIER(move)
160+
IDENTIFIER(vectorView)
160161

161162
// Kinds of layout constraints
162163
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ PROTOCOL_(TensorFlowDataTypeCompatible)
8787
PROTOCOL(TensorProtocol)
8888
PROTOCOL(VectorProtocol)
8989
PROTOCOL(Differentiable)
90+
PROTOCOL(EuclideanDifferentiable)
9091

9192
PROTOCOL_(ObjectiveCBridgeable)
9293
PROTOCOL_(DestructorSafeContainer)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4223,6 +4223,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42234223
case KnownProtocolKind::TensorProtocol:
42244224
case KnownProtocolKind::VectorProtocol:
42254225
case KnownProtocolKind::Differentiable:
4226+
case KnownProtocolKind::EuclideanDifferentiable:
42264227
return SpecialProtocol::None;
42274228
}
42284229

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 155 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232

3333
using namespace swift;
3434

35-
// Return the protocol requirement with the specified name.
36-
// TODO: Move function to shared place for use with other derived conformances.
35+
/// Return the protocol requirement with the specified name.
36+
/// TODO: Move function to shared place for use with other derived conformances.
3737
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
3838
auto lookup = proto->lookupDirect(name);
3939
// Erase declarations that are not protocol requirements.
@@ -46,8 +46,8 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
4646
return lookup.front();
4747
}
4848

49-
// Get the stored properties of a nominal type that are relevant for
50-
// differentiation, except the ones tagged `@noDerivative`.
49+
/// Get the stored properties of a nominal type that are relevant for
50+
/// differentiation, except the ones tagged `@noDerivative`.
5151
static void
5252
getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal,
5353
DeclContext *DC,
@@ -71,8 +71,8 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal,
7171
}
7272
}
7373

74-
// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
75-
// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
74+
/// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
75+
/// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
7676
static StructDecl *convertToStructDecl(ValueDecl *v) {
7777
if (auto *structDecl = dyn_cast<StructDecl>(v))
7878
return structDecl;
@@ -83,10 +83,10 @@ static StructDecl *convertToStructDecl(ValueDecl *v) {
8383
typeDecl->getDeclaredInterfaceType()->getAnyNominal());
8484
}
8585

86-
// Get the `Differentiable` protocol `TangentVector` associated type for the
87-
// given `VarDecl`.
88-
// TODO: Generalize and move function to shared place for use with other derived
89-
// conformances.
86+
/// Get the `Differentiable` protocol `TangentVector` associated type for the
87+
/// given `VarDecl`.
88+
/// TODO: Generalize and move function to shared place for use with other derived
89+
/// conformances.
9090
static Type getTangentVectorType(VarDecl *decl, DeclContext *DC) {
9191
auto &C = decl->getASTContext();
9292
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
@@ -196,7 +196,34 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
196196
});
197197
}
198198

199-
// Synthesize body for a `Differentiable` method requirement.
199+
/// Determine if a EuclideanDifferentiable requirement can be derived for a type.
200+
///
201+
/// \returns True if the requirement can be derived.
202+
bool DerivedConformance::canDeriveEuclideanDifferentiable(
203+
NominalTypeDecl *nominal, DeclContext *DC) {
204+
if (!canDeriveDifferentiable(nominal, DC))
205+
return false;
206+
auto &C = nominal->getASTContext();
207+
auto *lazyResolver = C.getLazyResolver();
208+
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
209+
// Return true if all differentiation stored properties conform to
210+
// `AdditiveArithmetic` and their `TangentVector` equals themselves.
211+
SmallVector<VarDecl *, 16> diffProperties;
212+
getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
213+
return llvm::all_of(diffProperties, [&](VarDecl *member) {
214+
if (!member->hasInterfaceType())
215+
lazyResolver->resolveDeclSignature(member);
216+
if (!member->hasInterfaceType())
217+
return false;
218+
auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType());
219+
if (!TypeChecker::conformsToProtocol(varType, addArithProto, DC, None))
220+
return false;
221+
auto memberAssocType = getTangentVectorType(member, DC);
222+
return member->getType()->isEqual(memberAssocType);
223+
});
224+
}
225+
226+
/// Synthesize body for a `Differentiable` method requirement.
200227
static std::pair<BraceStmt *, bool>
201228
deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
202229
Identifier methodName,
@@ -283,15 +310,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
283310
return std::pair<BraceStmt *, bool>(braceStmt, false);
284311
}
285312

286-
// Synthesize body for `move(along:)`.
313+
/// Synthesize body for `move(along:)`.
287314
static std::pair<BraceStmt *, bool>
288315
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
289316
auto &C = funcDecl->getASTContext();
290317
return deriveBodyDifferentiable_method(funcDecl, C.Id_move,
291318
C.getIdentifier("along"));
292319
}
293320

294-
// Synthesize function declaration for a `Differentiable` method requirement.
321+
/// Synthesize function declaration for a `Differentiable` method requirement.
295322
static ValueDecl *deriveDifferentiable_method(
296323
DerivedConformance &derived, Identifier methodName, Identifier argumentName,
297324
Identifier parameterName, Type parameterType, Type returnType,
@@ -329,7 +356,7 @@ static ValueDecl *deriveDifferentiable_method(
329356
return funcDecl;
330357
}
331358

332-
// Synthesize the `move(along:)` function declaration.
359+
/// Synthesize the `move(along:)` function declaration.
333360
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
334361
auto &C = derived.TC.Context;
335362
auto *parentDC = derived.getConformanceContext();
@@ -343,8 +370,92 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
343370
{deriveBodyDifferentiable_move, nullptr});
344371
}
345372

346-
// Return associated `TangentVector` struct for a nominal type, if it exists.
347-
// If not, synthesize the struct.
373+
/// Synthesize the `vectorView` property declaration.
374+
static ValueDecl *deriveEuclideanDifferentiable_vectorView(
375+
DerivedConformance &derived) {
376+
auto &C = derived.TC.Context;
377+
auto *parentDC = derived.getConformanceContext();
378+
379+
auto *tangentDecl = getTangentVectorStructDecl(parentDC);
380+
auto tangentType = tangentDecl->getDeclaredInterfaceType();
381+
auto tangentContextualType = parentDC->mapTypeIntoContext(tangentType);
382+
383+
VarDecl *vectorViewDecl;
384+
PatternBindingDecl *pbDecl;
385+
std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty(
386+
C.Id_vectorView, tangentType, tangentContextualType, /*isStatic*/ false,
387+
/*isFinal*/ true);
388+
389+
struct GetterSynthesizerContext {
390+
StructDecl *tangentDecl;
391+
Type tangentContextualType;
392+
};
393+
394+
auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl, void *ctx)
395+
-> std::pair<BraceStmt *, bool> {
396+
auto *context = reinterpret_cast<GetterSynthesizerContext *>(ctx);
397+
assert(context && "Invalid context");
398+
auto *parentDC = getterDecl->getParent();
399+
auto *nominal = parentDC->getSelfNominalTypeDecl();
400+
auto &C = nominal->getASTContext();
401+
SmallVector<VarDecl *, 8> diffProperties;
402+
getStoredPropertiesForDifferentiation(nominal, nominal->getDeclContext(),
403+
diffProperties);
404+
405+
// Create a reference to the memberwise initializer: `TangentVector.init`.
406+
auto *memberwiseInitDecl =
407+
context->tangentDecl->getEffectiveMemberwiseInitializer();
408+
assert(memberwiseInitDecl && "Memberwise initializer must exist");
409+
// `TangentVector`
410+
auto *tangentTypeExpr =
411+
TypeExpr::createImplicit(context->tangentContextualType, C);
412+
// `TangentVector.init`
413+
auto *initDRE = new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(),
414+
/*Implicit*/ true);
415+
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
416+
auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr);
417+
initExpr->setThrows(false);
418+
initExpr->setImplicit();
419+
420+
// Create a call:
421+
// TangentVector.init(
422+
// <property_name_1...>: self.<property_name_1>,
423+
// <property_name_2...>: self.<property_name_2>,
424+
// ...
425+
// )
426+
SmallVector<Identifier, 8> argLabels;
427+
SmallVector<Expr *, 8> memberRefs;
428+
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
429+
DeclNameLoc(),
430+
/*Implicit*/ true);
431+
for (auto *member : diffProperties) {
432+
argLabels.push_back(member->getName());
433+
memberRefs.push_back(
434+
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
435+
/*Implicit*/ true));
436+
}
437+
assert(memberRefs.size() == argLabels.size());
438+
CallExpr *callExpr =
439+
CallExpr::createImplicit(C, initExpr, memberRefs, argLabels);
440+
441+
// Create a return statement: `return TangentVector.init(...)`.
442+
ASTNode retStmt =
443+
new (C) ReturnStmt(SourceLoc(), callExpr, /*implicit*/ true);
444+
auto *braceStmt = BraceStmt::create(C, SourceLoc(), retStmt, SourceLoc(),
445+
/*implicit*/ true);
446+
return std::make_pair(braceStmt, false);
447+
};
448+
auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
449+
vectorViewDecl, tangentContextualType);
450+
getterDecl->setBodySynthesizer(
451+
getterSynthesizer, /*context*/ C.AllocateObjectCopy(
452+
GetterSynthesizerContext{tangentDecl, tangentContextualType}));
453+
derived.addMembersToConformanceContext({vectorViewDecl, pbDecl});
454+
return vectorViewDecl;
455+
}
456+
457+
/// Return associated `TangentVector` struct for a nominal type, if it exists.
458+
/// If not, synthesize the struct.
348459
static StructDecl *
349460
getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
350461
auto &TC = derived.TC;
@@ -362,8 +473,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
362473
return structDecl;
363474
}
364475

365-
// Otherwise, synthesize a new struct. The struct must conform to
366-
// `Differentiable`.
476+
// Otherwise, synthesize a new struct.
367477
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
368478
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
369479
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
@@ -378,9 +488,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
378488
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
379489
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());
380490

381-
SmallVector<TypeLoc, 4> inherited{diffableType};
382-
// `TangentVector` must conform to `AdditiveArithmetic`.
383-
inherited.push_back(addArithType);
491+
// By definition, `TangentVector` must conform to `Differentiable` and
492+
// `AdditiveArithmetic`.
493+
SmallVector<TypeLoc, 4> inherited{diffableType, addArithType};
384494

385495
// Cache original members and their associated types for later use.
386496
SmallVector<VarDecl *, 8> diffProperties;
@@ -551,8 +661,8 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
551661
return structDecl;
552662
}
553663

554-
// Add a typealias declaration with the given name and underlying target
555-
// struct type to the given source nominal declaration context.
664+
/// Add a typealias declaration with the given name and underlying target
665+
/// struct type to the given source nominal declaration context.
556666
static void addAssociatedTypeAliasDecl(Identifier name,
557667
DeclContext *sourceDC,
558668
StructDecl *target,
@@ -585,11 +695,11 @@ static void addAssociatedTypeAliasDecl(Identifier name,
585695
C.addSynthesizedDecl(aliasDecl);
586696
};
587697

588-
// Diagnose stored properties in the nominal that do not have an explicit
589-
// `@noDerivative` attribute, but either:
590-
// - Do not conform to `Differentiable`.
591-
// - Are a `let` stored property.
592-
// Emit a warning and a fixit so that users will make the attribute explicit.
698+
/// Diagnose stored properties in the nominal that do not have an explicit
699+
/// `@noDerivative` attribute, but either:
700+
/// - Do not conform to `Differentiable`.
701+
/// - Are a `let` stored property.
702+
/// Emit a warning and a fixit so that users will make the attribute explicit.
593703
static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
594704
NominalTypeDecl *nominal,
595705
DeclContext* DC) {
@@ -637,7 +747,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
637747
}
638748
}
639749

640-
// Get or synthesize `TangentVector` struct type.
750+
/// Get or synthesize `TangentVector` struct type.
641751
static Type
642752
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
643753
auto &TC = derived.TC;
@@ -669,7 +779,7 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
669779
tangentStruct->getDeclaredInterfaceType());
670780
}
671781

672-
// Synthesize the `TangentVector` struct type.
782+
/// Synthesize the `TangentVector` struct type.
673783
static Type
674784
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
675785
auto &TC = derived.TC;
@@ -756,3 +866,18 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
756866
TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement);
757867
return nullptr;
758868
}
869+
870+
/// Derive a EuclideanDifferentiable requirement for a nominal type.
871+
///
872+
/// \returns the derived member, which will also be added to the type.
873+
ValueDecl *DerivedConformance::deriveEuclideanDifferentiable(
874+
ValueDecl *requirement) {
875+
// Diagnose conformances in disallowed contexts.
876+
if (checkAndDiagnoseDisallowedContext(requirement))
877+
return nullptr;
878+
if (requirement->getFullName() == TC.Context.Id_vectorView)
879+
return deriveEuclideanDifferentiable_vectorView(*this);
880+
TC.diagnose(requirement->getLoc(),
881+
diag::broken_euclidean_differentiable_requirement);
882+
return nullptr;
883+
}

lib/Sema/DerivedConformances.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC,
9494
if (*knownProtocol == KnownProtocolKind::Differentiable)
9595
return canDeriveDifferentiable(Nominal, DC);
9696

97+
// SWIFT_ENABLE_TENSORFLOW
98+
if (*knownProtocol == KnownProtocolKind::EuclideanDifferentiable)
99+
return canDeriveEuclideanDifferentiable(Nominal, DC);
100+
97101
if (auto *enumDecl = dyn_cast<EnumDecl>(Nominal)) {
98102
switch (*knownProtocol) {
99103
// The presence of a raw type is an explicit declaration that
@@ -255,6 +259,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
255259
if (name.isSimpleName(ctx.Id_zero))
256260
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
257261

262+
// SWIFT_ENABLE_TENSORFLOW
263+
// EuclideanDifferentiable.vectorView
264+
if (name.isSimpleName(ctx.Id_vectorView))
265+
return getRequirement(KnownProtocolKind::EuclideanDifferentiable);
266+
258267
// SWIFT_ENABLE_TENSORFLOW
259268
// PointwiseMultiplicative.one
260269
if (name.isSimpleName(ctx.Id_one))

lib/Sema/DerivedConformances.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,17 @@ class DerivedConformance {
292292
/// \returns the derived member, which will also be added to the type.
293293
ValueDecl *deriveDifferentiable(ValueDecl *requirement);
294294

295+
/// Determine if a Differentiable requirement can be derived for a type.
296+
///
297+
/// \returns True if the requirement can be derived.
298+
static bool canDeriveEuclideanDifferentiable(NominalTypeDecl *type,
299+
DeclContext *DC);
300+
301+
/// Derive a EuclideanDifferentiable requirement for a nominal type.
302+
///
303+
/// \returns the derived member, which will also be added to the type.
304+
ValueDecl *deriveEuclideanDifferentiable(ValueDecl *requirement);
305+
295306
/// Derive a Differentiable type witness for a nominal type.
296307
///
297308
/// \returns the derived member, which will also be added to the type.

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5363,6 +5363,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
53635363
case KnownProtocolKind::Differentiable:
53645364
return derived.deriveDifferentiable(Requirement);
53655365

5366+
// SWIFT_ENABLE_TENSORFLOW
5367+
case KnownProtocolKind::EuclideanDifferentiable:
5368+
return derived.deriveEuclideanDifferentiable(Requirement);
5369+
53665370
default:
53675371
return nullptr;
53685372
}

stdlib/public/core/AutoDiff.swift

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,11 @@ public extension Differentiable where TangentVector == Self {
229229
/// `TangentVector` is equal to its vector space component.
230230
public protocol EuclideanDifferentiable: Differentiable {
231231
/// The differentiable vector component of `self`.
232-
var vectorView: TangentVector { get set }
232+
var vectorView: TangentVector { get }
233233
}
234234

235235
public extension EuclideanDifferentiable where TangentVector == Self {
236-
var vectorView: TangentVector {
237-
_read { yield self }
238-
_modify { yield &self }
239-
}
236+
var vectorView: TangentVector { _read { yield self } }
240237
}
241238

242239
/// Returns `x` like an identity function. When used in a context where `x` is

0 commit comments

Comments
 (0)