Skip to content

[Sema] [AutoDiff] Derive @_fieldwiseProductSpace when tangent/cotangent is Self #21737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ ERROR(autodiff_unsupported_type,none,
"differentiating '%0' is not supported yet", (Type))
ERROR(autodiff_function_not_differentiable,none,
"function is not differentiable", ())
// TODO: Change this to a note.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to change this to a note?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, here's why. The DifferentiationInvoker infra keeps track of how each differentiation task is created. Only the diagnostic on the very bottom of the invocation stack should be an error.

func f0(x) {
  f1()
}
func f1(x) {
  f2(x)
}
func f2(x) {
  f3(x) // non-differentiable
}
derivative(of: f0)

If expression f3() is not differentiable, we should not emit an error at f3() because it's not the place where the user requested the derivative and the user should not be told to change the code here before knowing what's actually causing differentiation to occur. Instead, we emit notes along the invocation stack until we reach the user's differentiation request, which is always the bottom of the invocation stack.

func f0(x) {
  f1() // note: when differentiating this function call
}
func f1(x) {
  f2(x) // note: when differentiating this function call
}
func f2(x) {
  f3(x) // note: expression is not differentiable
}
derivative(of: f0) // error: f0 is not differentiable

As to why this cannot be changed to a note today: We used to have DifferentialOperator as a case in DifferentiationInvoker, and the differential operator was, as you know, #gradient. When we switched over to generalized differentiability™, there is no more compiler-known differential operator -- function conversion to @autodiff is the only formal way to trigger differentiation in expressions. We didn't add a case for function conversion in DifferentiationInvoker, so the user request is not recorded, and there is no error at the user request. If we change this error to a note today, the pass pipeline would assume there's no error, keep running other passes on partially/incorrectly transformed code, and cause crashers or undefined runtime behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS: This definitely needs to be fixed. Let me know if you wanna make diagnostics better :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed info! Definitely interested in improving AD diagnostics sometime.

ERROR(autodiff_property_not_differentiable,none,
"property is not differentiable", ())
NOTE(autodiff_function_generic_functions_unsupported,none,
Expand Down
16 changes: 15 additions & 1 deletion lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,21 @@ deriveDifferentiable_VectorSpace(DerivedConformance &derived,
parentDC, ConformanceCheckFlags::Used);
// Return `Self` if conditions are met.
if (allMembersVectorSpaceEqualsSelf && nominalConformsToAddArith) {
return parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
auto selfType =
parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
auto *aliasDecl = new (C) TypeAliasDecl(
SourceLoc(), SourceLoc(), getVectorSpaceIdentifier(kind, C),
SourceLoc(), {}, nominal);
aliasDecl->setUnderlyingType(selfType);
aliasDecl->setImplicit();
aliasDecl->getAttrs().add(
new (C) FieldwiseProductSpaceAttr(/*implicit*/ true));
nominal->addMember(aliasDecl);
aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
aliasDecl->setValidationToChecked();
TC.validateDecl(aliasDecl);
C.addSynthesizedDecl(aliasDecl);
return selfType;
}

// Get or synthesize both `TangentVector` and `CotangentVector` structs at
Expand Down
27 changes: 21 additions & 6 deletions test/AutoDiff/derived_differentiable_properties.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-AST
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s --check-prefix=CHECK-SILGEN
// RUN: %target-swift-frontend -emit-sil -verify %s

public struct Foo : Differentiable {
public var a: Float
}

// CHECK-LABEL: public struct Foo : Differentiable {
// CHECK: @sil_stored @differentiable()
// CHECK: public var a: Float { get set }
// CHECK-AST-LABEL: public struct Foo : Differentiable {
// CHECK-AST: @sil_stored @differentiable()
// CHECK-AST: public var a: Float { get set }

// CHECK-SILGEN-LABEL: // Foo.a.getter
// CHECK-SILGEN: sil [transparent] [serialized] [differentiable source 0 wrt 0] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float

struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
var a: Float
}
let _: @autodiff (AdditiveTangentIsSelf) -> Float = { x in
x.a + x.a
}

// CHECK-AST-LABEL: struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable {
// CHECK-AST-NOT: @differentiable
// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = AdditiveTangentIsSelf
// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = AdditiveTangentIsSelf

// CHECK-LABEL: // Foo.a.getter
// CHECK: sil [transparent] [serialized] [differentiable source 0 wrt 0] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float