Skip to content

Commit f4cf3d9

Browse files
committed
[AutoDiff] [stdlib] Derive conformances for 'EuclideanDifferentiable'.
* Add derived conformances for the `EuclideanDifferentiable` protocol introduced in #26287 when a conforming type satisfies `Differentiable` conformance synthesis requirements. ```swift struct Foo<T>: EuclideanDifferentiable where T.TangentVector == T { var x: T var y: T @noDerivative var z: Bool // The compiler synthesizes the following `EuclideanDifferentiable` requirement: // var vectorView: TangentVector { // return TangentVector(x: x, y: y) // } } ``` * Remove `vectorView`'s setter. This should not have been added to the protocol, for the same reason as [TF-208](https://bugs.swift.org/browse/TF-208). A projection (`vectorView`) of a subset of properties should reflect these properties joint mutability, which is impossible to express at the moment. Resolves [TF-777](https://bugs.swift.org/browse/TF-777).
1 parent aa4fd59 commit f4cf3d9

13 files changed

+241
-44
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,6 +2576,8 @@ ERROR(broken_vector_protocol_requirement,none,
25762576
"VectorProtocol protocol is broken: unexpected requirement", ())
25772577
ERROR(broken_differentiable_requirement,none,
25782578
"Differentiable protocol is broken: unexpected requirement", ())
2579+
ERROR(broken_euclidean_differentiable_requirement,none,
2580+
"EuclideanDifferentiable protocol is broken: unexpected requirement", ())
25792581
ERROR(broken_key_path_iterable_requirement,none,
25802582
"KeyPathIterable protocol is broken: unexpected requirement", ())
25812583
ERROR(broken_tensor_array_protocol_requirement,none,

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ IDENTIFIER(x)
157157
// Differentiable
158158
IDENTIFIER(TangentVector)
159159
IDENTIFIER(move)
160+
IDENTIFIER(vectorView)
160161

161162
// Kinds of layout constraints
162163
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ PROTOCOL_(TensorFlowDataTypeCompatible)
8787
PROTOCOL(TensorProtocol)
8888
PROTOCOL(VectorProtocol)
8989
PROTOCOL(Differentiable)
90+
PROTOCOL(EuclideanDifferentiable)
9091

9192
PROTOCOL_(ObjectiveCBridgeable)
9293
PROTOCOL_(DestructorSafeContainer)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4223,6 +4223,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42234223
case KnownProtocolKind::TensorProtocol:
42244224
case KnownProtocolKind::VectorProtocol:
42254225
case KnownProtocolKind::Differentiable:
4226+
case KnownProtocolKind::EuclideanDifferentiable:
42264227
return SpecialProtocol::None;
42274228
}
42284229

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 130 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,33 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
196196
});
197197
}
198198

199+
/// Determine if a EuclideanDifferentiable requirement can be derived for a type.
200+
///
201+
/// \returns True if the requirement can be derived.
202+
bool DerivedConformance::canDeriveEuclideanDifferentiable(
203+
NominalTypeDecl *nominal, DeclContext *DC) {
204+
if (!canDeriveDifferentiable(nominal, DC))
205+
return false;
206+
auto &C = nominal->getASTContext();
207+
auto *lazyResolver = C.getLazyResolver();
208+
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
209+
// Return true if all differentiation stored properties conform to
210+
// `AdditiveArithmetic` and their `TangentVector` equals themselves.
211+
SmallVector<VarDecl *, 16> diffProperties;
212+
getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
213+
return llvm::all_of(diffProperties, [&](VarDecl *member) {
214+
if (!member->hasInterfaceType())
215+
lazyResolver->resolveDeclSignature(member);
216+
if (!member->hasInterfaceType())
217+
return false;
218+
auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType());
219+
if (!TypeChecker::conformsToProtocol(varType, addArithProto, DC, None))
220+
return false;
221+
auto memberAssocType = getTangentVectorType(member, DC);
222+
return member->getType()->isEqual(memberAssocType);
223+
});
224+
}
225+
199226
// Synthesize body for a `Differentiable` method requirement.
200227
static std::pair<BraceStmt *, bool>
201228
deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
@@ -343,6 +370,90 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
343370
{deriveBodyDifferentiable_move, nullptr});
344371
}
345372

373+
// Synthesize the `vectorView` property declaration.
374+
static ValueDecl *deriveEuclideanDifferentiable_vectorView(
375+
DerivedConformance &derived) {
376+
auto &C = derived.TC.Context;
377+
auto *parentDC = derived.getConformanceContext();
378+
379+
auto *tangentDecl = getTangentVectorStructDecl(parentDC);
380+
auto tangentType = tangentDecl->getDeclaredInterfaceType();
381+
auto tangentContextualType = parentDC->mapTypeIntoContext(tangentType);
382+
383+
VarDecl *vectorViewDecl;
384+
PatternBindingDecl *pbDecl;
385+
std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty(
386+
C.Id_vectorView, tangentType, tangentContextualType, /*isStatic*/ false,
387+
/*isFinal*/ true);
388+
389+
struct GetterSynthesizerContext {
390+
StructDecl *tangentDecl;
391+
Type tangentContextualType;
392+
};
393+
394+
auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl, void *ctx)
395+
-> std::pair<BraceStmt *, bool> {
396+
auto *context = reinterpret_cast<GetterSynthesizerContext *>(ctx);
397+
assert(context && "Invalid context");
398+
auto *parentDC = getterDecl->getParent();
399+
auto *nominal = parentDC->getSelfNominalTypeDecl();
400+
auto &C = nominal->getASTContext();
401+
SmallVector<VarDecl *, 8> diffProperties;
402+
getStoredPropertiesForDifferentiation(nominal, nominal->getDeclContext(),
403+
diffProperties);
404+
405+
// Create a reference to the memberwise initializer: `TangentVector.init`.
406+
auto *memberwiseInitDecl =
407+
context->tangentDecl->getEffectiveMemberwiseInitializer();
408+
assert(memberwiseInitDecl && "Memberwise initializer must exist");
409+
// `TangentVector`
410+
auto *tangentTypeExpr =
411+
TypeExpr::createImplicit(context->tangentContextualType, C);
412+
// `TangentVector.init`
413+
auto *initDRE = new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(),
414+
/*Implicit*/ true);
415+
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
416+
auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr);
417+
initExpr->setThrows(false);
418+
initExpr->setImplicit();
419+
420+
// Create a call:
421+
// TangentVector.init(
422+
// <property_name_1...>: self.<property_name_1>,
423+
// <property_name_2...>: self.<property_name_2>,
424+
// ...
425+
// )
426+
SmallVector<Identifier, 8> argLabels;
427+
SmallVector<Expr *, 8> memberRefs;
428+
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
429+
DeclNameLoc(),
430+
/*Implicit*/ true);
431+
for (auto *member : diffProperties) {
432+
argLabels.push_back(member->getName());
433+
memberRefs.push_back(
434+
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
435+
/*Implicit*/ true));
436+
}
437+
assert(memberRefs.size() == argLabels.size());
438+
CallExpr *callExpr =
439+
CallExpr::createImplicit(C, initExpr, memberRefs, argLabels);
440+
441+
// Create a return statement: `return TangentVector.init(...)`.
442+
ASTNode retStmt =
443+
new (C) ReturnStmt(SourceLoc(), callExpr, /*implicit*/ true);
444+
auto *braceStmt = BraceStmt::create(C, SourceLoc(), retStmt, SourceLoc(),
445+
/*implicit*/ true);
446+
return std::make_tuple(braceStmt, false);
447+
};
448+
auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
449+
vectorViewDecl, tangentContextualType);
450+
getterDecl->setBodySynthesizer(
451+
getterSynthesizer, /*context*/ C.AllocateObjectCopy(
452+
GetterSynthesizerContext{tangentDecl, tangentContextualType}));
453+
derived.addMembersToConformanceContext({vectorViewDecl, pbDecl});
454+
return vectorViewDecl;
455+
}
456+
346457
// Return associated `TangentVector` struct for a nominal type, if it exists.
347458
// If not, synthesize the struct.
348459
static StructDecl *
@@ -362,8 +473,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
362473
return structDecl;
363474
}
364475

365-
// Otherwise, synthesize a new struct. The struct must conform to
366-
// `Differentiable`.
476+
// Otherwise, synthesize a new struct.
367477
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
368478
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
369479
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
@@ -378,9 +488,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
378488
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
379489
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());
380490

381-
SmallVector<TypeLoc, 4> inherited{diffableType};
382-
// `TangentVector` must conform to `AdditiveArithmetic`.
383-
inherited.push_back(addArithType);
491+
// By definition, `TangentVector` must conform to `EuclideanDifferentiable`
492+
// and `AdditiveArithmetic`.
493+
SmallVector<TypeLoc, 4> inherited{diffableType, addArithType};
384494

385495
// Cache original members and their associated types for later use.
386496
SmallVector<VarDecl *, 8> diffProperties;
@@ -756,3 +866,18 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
756866
TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement);
757867
return nullptr;
758868
}
869+
870+
/// Derive a EuclideanDifferentiable requirement for a nominal type.
871+
///
872+
/// \returns the derived member, which will also be added to the type.
873+
ValueDecl *DerivedConformance::deriveEuclideanDifferentiable(
874+
ValueDecl *requirement) {
875+
// Diagnose conformances in disallowed contexts.
876+
if (checkAndDiagnoseDisallowedContext(requirement))
877+
return nullptr;
878+
if (requirement->getFullName() == TC.Context.Id_vectorView)
879+
return deriveEuclideanDifferentiable_vectorView(*this);
880+
TC.diagnose(requirement->getLoc(),
881+
diag::broken_euclidean_differentiable_requirement);
882+
return nullptr;
883+
}

lib/Sema/DerivedConformances.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC,
9494
if (*knownProtocol == KnownProtocolKind::Differentiable)
9595
return canDeriveDifferentiable(Nominal, DC);
9696

97+
// SWIFT_ENABLE_TENSORFLOW
98+
if (*knownProtocol == KnownProtocolKind::EuclideanDifferentiable)
99+
return canDeriveEuclideanDifferentiable(Nominal, DC);
100+
97101
if (auto *enumDecl = dyn_cast<EnumDecl>(Nominal)) {
98102
switch (*knownProtocol) {
99103
// The presence of a raw type is an explicit declaration that
@@ -255,6 +259,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
255259
if (name.isSimpleName(ctx.Id_zero))
256260
return getRequirement(KnownProtocolKind::AdditiveArithmetic);
257261

262+
// SWIFT_ENABLE_TENSORFLOW
263+
// EuclideanDifferentiable.vectorView
264+
if (name.isSimpleName(ctx.Id_vectorView))
265+
return getRequirement(KnownProtocolKind::EuclideanDifferentiable);
266+
258267
// SWIFT_ENABLE_TENSORFLOW
259268
// PointwiseMultiplicative.one
260269
if (name.isSimpleName(ctx.Id_one))

lib/Sema/DerivedConformances.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,17 @@ class DerivedConformance {
292292
/// \returns the derived member, which will also be added to the type.
293293
ValueDecl *deriveDifferentiable(ValueDecl *requirement);
294294

295+
/// Determine if a Differentiable requirement can be derived for a type.
296+
///
297+
/// \returns True if the requirement can be derived.
298+
static bool canDeriveEuclideanDifferentiable(NominalTypeDecl *type,
299+
DeclContext *DC);
300+
301+
/// Derive a EuclideanDifferentiable requirement for a nominal type.
302+
///
303+
/// \returns the derived member, which will also be added to the type.
304+
ValueDecl *deriveEuclideanDifferentiable(ValueDecl *requirement);
305+
295306
/// Derive a Differentiable type witness for a nominal type.
296307
///
297308
/// \returns the derived member, which will also be added to the type.

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5363,6 +5363,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
53635363
case KnownProtocolKind::Differentiable:
53645364
return derived.deriveDifferentiable(Requirement);
53655365

5366+
// SWIFT_ENABLE_TENSORFLOW
5367+
case KnownProtocolKind::EuclideanDifferentiable:
5368+
return derived.deriveEuclideanDifferentiable(Requirement);
5369+
53665370
default:
53675371
return nullptr;
53685372
}

stdlib/public/core/AutoDiff.swift

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,14 +229,12 @@ public extension Differentiable where TangentVector == Self {
229229
/// `TangentVector` is equal to its vector space component.
230230
public protocol EuclideanDifferentiable: Differentiable {
231231
/// The differentiable vector component of `self`.
232-
var vectorView: TangentVector { get set }
232+
var vectorView: TangentVector { get }
233233
}
234234

235-
public extension EuclideanDifferentiable where TangentVector == Self {
236-
var vectorView: TangentVector {
237-
_read { yield self }
238-
_modify { yield &self }
239-
}
235+
public extension EuclideanDifferentiable
236+
where TangentVector: EuclideanDifferentiable, TangentVector == Self {
237+
var vectorView: TangentVector { _read { yield self } }
240238
}
241239

242240
/// Returns `x` like an identity function. When used in a context where `x` is

test/AutoDiff/derived_conformances.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,13 @@ DerivedConformanceTests.test("MemberwiseInitializers") {
3030
HasNoDerivativeConstant.TangentVector.zero)
3131
}
3232

33+
DerivedConformanceTests.test("EuclideanVectorView") {
34+
struct Foo: EuclideanDifferentiable {
35+
var x: SIMD4<Float>
36+
@noDerivative var y: SIMD4<Int32>
37+
}
38+
let x = Foo(x: [1, 2, 3, 4], y: .zero)
39+
expectEqual(Foo.TangentVector(x: [1, 2, 3, 4]), x.vectorView)
40+
}
41+
3342
runAllTests()

test/AutoDiff/derived_differentiable.swift

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,51 @@
22
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SILGEN
33
// RUN: %target-swift-frontend -emit-sil -verify %s
44

5-
struct PointwiseMultiplicativeDummy : Differentiable, PointwiseMultiplicative {}
5+
struct PointwiseMultiplicativeDummy : EuclideanDifferentiable, PointwiseMultiplicative {}
66

7-
public struct Foo : Differentiable {
7+
public struct Foo : EuclideanDifferentiable {
88
public var a: Float
99
}
1010

11-
// CHECK-AST-LABEL: public struct Foo : Differentiable {
11+
// CHECK-AST-LABEL: public struct Foo : EuclideanDifferentiable {
1212
// CHECK-AST: @differentiable
1313
// CHECK-AST: public var a: Float
1414
// CHECK-AST: internal init(a: Float)
1515
// CHECK-AST: public struct TangentVector
1616
// CHECK-AST: public typealias TangentVector = Foo.TangentVector
17+
// CHECK-AST: public var vectorView: Foo.TangentVector { get }
1718

1819
// CHECK-SILGEN-LABEL: // Foo.a.getter
1920
// CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s22derived_differentiable3FooV1aSfvg : $@convention(method) (Foo) -> Float
2021

21-
struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
22+
struct AdditiveTangentIsSelf : AdditiveArithmetic, EuclideanDifferentiable {
2223
var a: Float
2324
var dummy: PointwiseMultiplicativeDummy
2425
}
2526
let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in
2627
x.a + x.a
2728
}
2829

29-
// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
30+
// CHECK-AST-LABEL: internal struct AdditiveTangentIsSelf : AdditiveArithmetic, EuclideanDifferentiable {
3031
// CHECK-AST: internal var a: Float
3132
// CHECK-AST: internal var dummy: PointwiseMultiplicativeDummy
3233
// CHECK-AST: internal init(a: Float, dummy: PointwiseMultiplicativeDummy)
3334
// CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf
35+
// The following should not exist because when `Self == Self.TangentVector`, `vectorView` is not synthesized.
36+
// CHECK-AST-NOT: internal var vectorView: AdditiveTangentIsSelf { get }
3437

35-
struct TestNoDerivative : Differentiable {
38+
struct TestNoDerivative : EuclideanDifferentiable {
3639
var w: Float
3740
@noDerivative var technicallyDifferentiable: Float
3841
}
3942

40-
// CHECK-AST-LABEL: internal struct TestNoDerivative : Differentiable {
43+
// CHECK-AST-LABEL: internal struct TestNoDerivative : EuclideanDifferentiable {
4144
// CHECK-AST: var w: Float
4245
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
4346
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
4447
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
4548
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
49+
// CHECK-AST: internal var vectorView: TestNoDerivative.TangentVector { get }
4650

4751
struct TestPointwiseMultiplicative : Differentiable {
4852
var w: PointwiseMultiplicativeDummy

0 commit comments

Comments
 (0)