Skip to content

Commit 1db433c

Browse files
authored
[Sema] [AutoDiff] Derive @_fieldwiseProductSpace when tangent/cotangent is Self (#21737)
* [Sema] [AutoDiff] Derive `@_fieldwiseProductSpace` when tangent/cotangent is Self. * Remove mapTypeOutOfContext() * Use boilerplate code on synthesized typealias declaration. The crucial line is copying formal access from nominal, which should fix the failing test.
1 parent 841526c commit 1db433c

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ ERROR(autodiff_unsupported_type,none,
374374
"differentiating '%0' is not supported yet", (Type))
375375
ERROR(autodiff_function_not_differentiable,none,
376376
"function is not differentiable", ())
377+
// TODO: Change this to a note.
377378
ERROR(autodiff_property_not_differentiable,none,
378379
"property is not differentiable", ())
379380
ERROR(autodiff_expression_is_not_differentiable_error,none,

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,21 @@ deriveDifferentiable_VectorSpace(DerivedConformance &derived,
577577
parentDC, ConformanceCheckFlags::Used);
578578
// Return `Self` if conditions are met.
579579
if (allMembersVectorSpaceEqualsSelf && nominalConformsToAddArith) {
580-
return parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
580+
auto selfType =
581+
parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
582+
auto *aliasDecl = new (C) TypeAliasDecl(
583+
SourceLoc(), SourceLoc(), getVectorSpaceIdentifier(kind, C),
584+
SourceLoc(), {}, nominal);
585+
aliasDecl->setUnderlyingType(selfType);
586+
aliasDecl->setImplicit();
587+
aliasDecl->getAttrs().add(
588+
new (C) FieldwiseProductSpaceAttr(/*implicit*/ true));
589+
nominal->addMember(aliasDecl);
590+
aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
591+
aliasDecl->setValidationToChecked();
592+
TC.validateDecl(aliasDecl);
593+
C.addSynthesizedDecl(aliasDecl);
594+
return selfType;
581595
}
582596

583597
// Get or synthesize both `TangentVector` and `CotangentVector` structs at

test/AutoDiff/derived_differentiable_properties.swift

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
1-
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
1+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-AST
2+
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SILGEN
3+
// RUN: %target-swift-frontend -emit-sil -verify %s
24

35
public struct Foo : Differentiable {
46
public var a: Float
57
}
68

7-
// CHECK-LABEL: public struct Foo : Differentiable {
8-
// CHECK: @sil_stored @differentiable()
9-
// CHECK: public var a: Float { get set }
9+
// CHECK-AST-LABEL: public struct Foo : Differentiable {
10+
// CHECK-AST: @sil_stored @differentiable()
11+
// CHECK-AST: public var a: Float { get set }
12+
13+
// CHECK-SILGEN-LABEL: // Foo.a.getter
14+
// CHECK-SILGEN: sil [transparent] [serialized] [differentiable source 0 wrt 0] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float
15+
16+
struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
17+
var a: Float
18+
}
19+
let _: @autodiff (AdditiveTangentIsSelf) -> Float = { x in
20+
x.a + x.a
21+
}
22+
23+
// CHECK-AST-LABEL: struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
24+
// CHECK-AST-NOT: @differentiable
25+
// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = AdditiveTangentIsSelf
26+
// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = AdditiveTangentIsSelf
1027

11-
// CHECK-LABEL: // Foo.a.getter
12-
// CHECK: sil [transparent] [serialized] [differentiable source 0 wrt 0] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float

0 commit comments

Comments
 (0)