Skip to content

Commit 66ede5a

Browse files
authored
[AutoDiff] Fix incorrectly sorted members of synthesized 'TangentVector'. (#36154)
Do not mark synthesized `TangentVector` members as synthesized so that the type checker won't sort them. We would like tangent vector members to have the same order as the properties in the parent declaration. Also add `typealias TangentVector = Self` to the synthesized `TangentVector` so that it will not need its own `Differentiable` conformance derivation. Resolves SR-14241 / rdar://74659803.
1 parent 1620f2c commit 66ede5a

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ using namespace swift;
4040
/// If the given property is a `var`, return true because `move(by:)` can be
4141
/// invoked regardless. Otherwise, return true if and only if the property's
4242
/// type's 'Differentiable.move(by:)' witness is non-mutating.
43-
static bool canInvokeMoveAlongOnProperty(
43+
static bool canInvokeMoveByOnProperty(
4444
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
4545
assert(diffableConformance && "Property must conform to 'Differentiable'");
4646
// `var` always supports `move(by:)` since it is mutable.
@@ -64,7 +64,7 @@ static void
6464
getStoredPropertiesForDifferentiation(
6565
NominalTypeDecl *nominal, DeclContext *DC,
6666
SmallVectorImpl<VarDecl *> &result,
67-
bool includeLetPropertiesWithNonmutatingMoveAlong = false) {
67+
bool includeLetPropertiesWithNonmutatingMoveBy = false) {
6868
auto &C = nominal->getASTContext();
6969
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
7070
for (auto *vd : nominal->getStoredProperties()) {
@@ -90,8 +90,8 @@ getStoredPropertiesForDifferentiation(
9090
// Skip `let` stored properties with a mutating `move(by:)` if requested.
9191
// `mutating func move(by:)` cannot be synthesized to update `let`
9292
// properties.
93-
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
94-
!canInvokeMoveAlongOnProperty(vd, conformance))
93+
if (!includeLetPropertiesWithNonmutatingMoveBy &&
94+
!canInvokeMoveByOnProperty(vd, conformance))
9595
continue;
9696
result.push_back(vd);
9797
}
@@ -450,9 +450,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
450450
auto *tangentProperty = new (C) VarDecl(
451451
member->isStatic(), member->getIntroducer(),
452452
/*NameLoc*/ SourceLoc(), member->getName(), structDecl);
453-
tangentProperty->setSynthesized();
454-
// Note: `tangentProperty` is not marked as implicit here, because that
455-
// incorrectly affects memberwise initializer synthesis.
453+
// Note: `tangentProperty` is not marked as implicit or synthesized here,
454+
// because that incorrectly affects memberwise initializer synthesis and
455+
// causes the type checker to not guarantee the order of these members.
456456
auto memberContextualType =
457457
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
458458
auto memberTanType =
@@ -507,20 +507,27 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
507507
}
508508
}
509509

510-
// If nominal type is `@_fixed_layout`, also mark `TangentVector` struct as
511-
// `@_fixed_layout`.
512-
if (nominal->getAttrs().hasAttribute<FixedLayoutAttr>())
513-
addFixedLayoutAttr(structDecl);
514-
515-
// If nominal type is `@frozen`, also mark `TangentVector` struct as
516-
// `@frozen`.
510+
// If nominal type is `@frozen`, also mark `TangentVector` struct.
517511
if (nominal->getAttrs().hasAttribute<FrozenAttr>())
518512
structDecl->getAttrs().add(new (C) FrozenAttr(/*implicit*/ true));
519-
520-
// If nominal type is `@usableFromInline`, also mark `TangentVector` struct as
521-
// `@usableFromInline`.
522-
if (nominal->getAttrs().hasAttribute<UsableFromInlineAttr>())
513+
514+
// Add `typealias TangentVector = Self` so that the `TangentVector` itself
515+
// won't need its own conformance derivation.
516+
auto *tangentEqualsSelfAlias = new (C) TypeAliasDecl(
517+
SourceLoc(), SourceLoc(), C.Id_TangentVector, SourceLoc(),
518+
/*GenericParams*/ nullptr, structDecl);
519+
tangentEqualsSelfAlias->setUnderlyingType(structDecl->getSelfTypeInContext());
520+
tangentEqualsSelfAlias->setAccess(structDecl->getFormalAccess());
521+
tangentEqualsSelfAlias->setImplicit();
522+
tangentEqualsSelfAlias->setSynthesized();
523+
structDecl->addMember(tangentEqualsSelfAlias);
524+
525+
// If nominal type is `@usableFromInline`, also mark `TangentVector` struct.
526+
if (nominal->getAttrs().hasAttribute<UsableFromInlineAttr>()) {
523527
structDecl->getAttrs().add(new (C) UsableFromInlineAttr(/*implicit*/ true));
528+
tangentEqualsSelfAlias->getAttrs().add(
529+
new (C) UsableFromInlineAttr(/*implicit*/ true));
530+
}
524531

525532
// The implicit memberwise constructor must be explicitly created so that it
526533
// can called in `AdditiveArithmetic` and `Differentiable` methods. Normally,
@@ -593,7 +600,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
593600
TypeChecker::conformsToProtocol(varType, diffableProto, nominal);
594601
// If stored property should not be diagnosed, continue.
595602
if (diffableConformance &&
596-
canInvokeMoveAlongOnProperty(vd, diffableConformance))
603+
canInvokeMoveByOnProperty(vd, diffableConformance))
597604
continue;
598605
// Otherwise, add an implicit `@noDerivative` attribute.
599606
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));

test/AutoDiff/Sema/DerivedConformances/derived_differentiable.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %target-swift-frontend -print-ast %s | %FileCheck %s --check-prefix=CHECK-AST
2+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SIL
23

34
import _Differentiation
45

@@ -156,3 +157,24 @@ extension TangentVectorP where Self == StructWithTangentVectorConstrained.Tangen
156157

157158
// CHECK-AST-LABEL: internal struct StructWithTangentVectorConstrained : TangentVectorConstrained {
158159
// CHECK-AST: internal struct TangentVector : {{(TangentVectorP, Differentiable, AdditiveArithmetic)|(TangentVectorP, AdditiveArithmetic, Differentiable)|(Differentiable, TangentVectorP, AdditiveArithmetic)|(AdditiveArithmetic, TangentVectorP, Differentiable)|(Differentiable, AdditiveArithmetic, TangentVectorP)|(AdditiveArithmetic, Differentiable, TangentVectorP)}} {
160+
161+
public struct SR14241Struct: Differentiable {
162+
public var simd: [Float]
163+
public var scalar: Float
164+
}
165+
166+
// CHECK-AST-LABEL: public struct SR14241Struct : Differentiable {
167+
// CHECK-AST: public var simd: [Float]
168+
// CHECK-AST: public var scalar: Float
169+
// CHECK-AST: struct TangentVector : AdditiveArithmetic, Differentiable {
170+
// CHECK-AST: var simd: Array<Float>.TangentVector
171+
// CHECK-AST: var scalar: Float
172+
173+
// CHECK-SIL-LABEL: public struct SR14241Struct : Differentiable {
174+
// CHECK-SIL: @differentiable(reverse, wrt: self)
175+
// CHECK-SIL: @_hasStorage public var simd: [Float] { get set }
176+
// CHECK-SIL: @differentiable(reverse, wrt: self)
177+
// CHECK-SIL: @_hasStorage public var scalar: Float { get set }
178+
// CHECK-SIL: struct TangentVector : AdditiveArithmetic, Differentiable {
179+
// CHECK-SIL: @_hasStorage var simd: Array<Float>.DifferentiableView { get set }
180+
// CHECK-SIL: @_hasStorage var scalar: Float { get set }

0 commit comments

Comments
 (0)