Skip to content

[AutoDiff upstream] Upstream attribute type-checking changes. #29945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2949,13 +2949,12 @@ ERROR(differentiable_attr_protocol_req_assoc_func,none,
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
"'@differentiable' attribute on stored property cannot specify "
"'jvp:' or 'vjp:'", ())
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
"'@differentiable' attribute cannot be declared on class methods "
ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none,
"'@differentiable' attribute cannot be declared on class members "
"returning 'Self'", ())
// TODO(TF-654): Remove when differentiation supports class initializers.
ERROR(differentiable_attr_class_init_not_yet_supported,none,
"'@differentiable' attribute does not yet support class initializers",
())
ERROR(differentiable_attr_nonfinal_class_init_unsupported,none,
"'@differentiable' attribute cannot be declared on 'init' in a non-final "
"class; consider making %0 final", (Type))
ERROR(differentiable_attr_empty_where_clause,none,
"empty 'where' clause in '@differentiable' attribute", ())
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
Expand Down Expand Up @@ -2994,6 +2993,12 @@ ERROR(derivative_attr_not_in_same_file_as_original,none,
"derivative not in the same file as the original function", ())
ERROR(derivative_attr_original_stored_property_unsupported,none,
"cannot register derivative for stored property %0", (DeclNameRef))
ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none,
"cannot register derivative for class member %0 returning 'Self'",
(DeclNameRef))
ERROR(derivative_attr_nonfinal_class_init_unsupported,none,
"cannot register derivative for 'init' in a non-final class; consider "
"making %0 final", (Type))
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))
NOTE(derivative_attr_duplicate_note,none,
Expand Down
77 changes: 58 additions & 19 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3872,32 +3872,39 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
return nullptr;
}

// Diagnose if original function is an invalid class member.
bool isOriginalClassMember = original->getDeclContext() &&
original->getDeclContext()->getSelfClassDecl();

// Diagnose if original function is an invalid class member.
if (isOriginalClassMember) {
// Class methods returning dynamic `Self` are not supported.
// (For class methods, dynamic `Self` is supported only as the single
// result - tuple-returning JVPs/VJPs would not type-check.)
if (auto *originalFn = dyn_cast<FuncDecl>(original)) {
if (originalFn->hasDynamicSelfResult()) {
diags.diagnose(attr->getLocation(),
diag::differentiable_attr_class_member_no_dynamic_self);
auto *classDecl = original->getDeclContext()->getSelfClassDecl();
assert(classDecl);
// Class members returning dynamic `Self` are not supported.
// Dynamic `Self` is supported only as a single top-level result for class
// members. JVP/VJP functions returning `(Self, ...)` tuples would not
// type-check.
bool diagnoseDynamicSelfResult = original->hasDynamicSelfResult();
if (diagnoseDynamicSelfResult) {
// Diagnose class initializers in non-final classes.
if (isa<ConstructorDecl>(original)) {
if (!classDecl->isFinal()) {
diags.diagnose(
attr->getLocation(),
diag::differentiable_attr_nonfinal_class_init_unsupported,
classDecl->getDeclaredInterfaceType());
attr->setInvalid();
return nullptr;
}
}
// Diagnose all other declarations returning dynamic `Self`.
else {
diags.diagnose(
attr->getLocation(),
diag::
differentiable_attr_class_member_dynamic_self_result_unsupported);
attr->setInvalid();
return nullptr;
}
}

// TODO(TF-654): Class initializers are not yet supported.
// Extra JVP/VJP type calculation logic is necessary because classes have
// both allocators and initializers.
if (isa<ConstructorDecl>(original)) {
diags.diagnose(attr->getLocation(),
diag::differentiable_attr_class_init_not_yet_supported);
attr->setInvalid();
return nullptr;
}
}

// Resolve the derivative generic signature.
Expand Down Expand Up @@ -4187,6 +4194,38 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return true;
}
}
// Diagnose if original function is an invalid class member.
bool isOriginalClassMember =
originalAFD->getDeclContext() &&
originalAFD->getDeclContext()->getSelfClassDecl();
if (isOriginalClassMember) {
auto *classDecl = originalAFD->getDeclContext()->getSelfClassDecl();
assert(classDecl);
// Class members returning dynamic `Self` are not supported.
// Dynamic `Self` is supported only as a single top-level result for class
// members. JVP/VJP functions returning `(Self, ...)` tuples would not
// type-check.
bool diagnoseDynamicSelfResult = originalAFD->hasDynamicSelfResult();
if (diagnoseDynamicSelfResult) {
// Diagnose class initializers in non-final classes.
if (isa<ConstructorDecl>(originalAFD)) {
if (!classDecl->isFinal()) {
diags.diagnose(attr->getLocation(),
diag::derivative_attr_nonfinal_class_init_unsupported,
classDecl->getDeclaredInterfaceType());
return true;
}
}
// Diagnose all other declarations returning dynamic `Self`.
else {
diags.diagnose(
attr->getLocation(),
diag::derivative_attr_class_member_dynamic_self_result_unsupported,
DeclNameRef(originalAFD->getFullName()));
return true;
}
}
}
attr->setOriginalFunction(originalAFD);

// Get the resolved differentiability parameter indices.
Expand Down
26 changes: 26 additions & 0 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,29 @@ where Self: Differentiable, Self == Self.TangentVector {
return (x + y, { v in (v, v) })
}
}

// Test derivatives of default implementations.
protocol HasADefaultImplementation {
func req(_ x: Float) -> Float
}
extension HasADefaultImplementation {
func req(_ x: Float) -> Float { x }
// ok
@derivative(of: req)
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 10 * $0 })
}
}

// Test default derivatives of requirements.
protocol HasADefaultDerivative {
func req(_ x: Float) -> Float
}
extension HasADefaultDerivative {
// TODO(TF-982): Make this ok.
// expected-error @+1 {{could not find function 'req'}}
@derivative(of: req)
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
(x, { 10 * $0 })
}
}
17 changes: 14 additions & 3 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1079,8 +1079,7 @@ class Super: Differentiable {

var base: Float

// NOTE(TF-654): Class initializers are not yet supported.
// expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on 'init' in a non-final class; consider making 'Super' final}}
@differentiable
init(base: Float) {
self.base = base
Expand Down Expand Up @@ -1123,7 +1122,7 @@ class Super: Differentiable {
func instanceMethod<T>(_ x: Float, y: T) -> Float { x }

// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class methods returning 'Self'}}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}}
@differentiable(vjp: vjpDynamicSelfResult)
func dynamicSelfResult() -> Self { self }

Expand All @@ -1147,6 +1146,18 @@ class Sub: Super {
override func testSuperclassDerivatives(_ x: Float) -> Float { x }
}

final class FinalClass: Differentiable {
typealias TangentVector = DummyTangentVector
func move(along _: TangentVector) {}

var base: Float

@differentiable
init(base: Float) {
self.base = base
}
}

// Test unsupported accessors: `set`, `_read`, `_modify`.

struct UnsupportedAccessors: Differentiable {
Expand Down