Skip to content

Commit e5e9fce

Browse files
authored
[AutoDiff upstream] Upstream attribute type-checking changes. (#29945)
- Support `@differentiable` and `@derivative` attributes for original initializers in final classes. Reject original initializers in non-final classes. - Synchronize tests.
1 parent 8c2cff8 commit e5e9fce

File tree

4 files changed

+109
-28
lines changed

4 files changed

+109
-28
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2971,13 +2971,12 @@ ERROR(differentiable_attr_protocol_req_assoc_func,none,
29712971
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
29722972
"'@differentiable' attribute on stored property cannot specify "
29732973
"'jvp:' or 'vjp:'", ())
2974-
ERROR(differentiable_attr_class_member_no_dynamic_self,none,
2975-
"'@differentiable' attribute cannot be declared on class methods "
2974+
ERROR(differentiable_attr_class_member_dynamic_self_result_unsupported,none,
2975+
"'@differentiable' attribute cannot be declared on class members "
29762976
"returning 'Self'", ())
2977-
// TODO(TF-654): Remove when differentiation supports class initializers.
2978-
ERROR(differentiable_attr_class_init_not_yet_supported,none,
2979-
"'@differentiable' attribute does not yet support class initializers",
2980-
())
2977+
ERROR(differentiable_attr_nonfinal_class_init_unsupported,none,
2978+
"'@differentiable' attribute cannot be declared on 'init' in a non-final "
2979+
"class; consider making %0 final", (Type))
29812980
ERROR(differentiable_attr_empty_where_clause,none,
29822981
"empty 'where' clause in '@differentiable' attribute", ())
29832982
ERROR(differentiable_attr_where_clause_for_nongeneric_original,none,
@@ -3016,6 +3015,12 @@ ERROR(derivative_attr_not_in_same_file_as_original,none,
30163015
"derivative not in the same file as the original function", ())
30173016
ERROR(derivative_attr_original_stored_property_unsupported,none,
30183017
"cannot register derivative for stored property %0", (DeclNameRef))
3018+
ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none,
3019+
"cannot register derivative for class member %0 returning 'Self'",
3020+
(DeclNameRef))
3021+
ERROR(derivative_attr_nonfinal_class_init_unsupported,none,
3022+
"cannot register derivative for 'init' in a non-final class; consider "
3023+
"making %0 final", (Type))
30193024
ERROR(derivative_attr_original_already_has_derivative,none,
30203025
"a derivative already exists for %0", (DeclName))
30213026
NOTE(derivative_attr_duplicate_note,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3886,32 +3886,39 @@ llvm::Expected<IndexSubset *> DifferentiableAttributeTypeCheckRequest::evaluate(
38863886
return nullptr;
38873887
}
38883888

3889+
// Diagnose if original function is an invalid class member.
38893890
bool isOriginalClassMember = original->getDeclContext() &&
38903891
original->getDeclContext()->getSelfClassDecl();
3891-
3892-
// Diagnose if original function is an invalid class member.
38933892
if (isOriginalClassMember) {
3894-
// Class methods returning dynamic `Self` are not supported.
3895-
// (For class methods, dynamic `Self` is supported only as the single
3896-
// result - tuple-returning JVPs/VJPs would not type-check.)
3897-
if (auto *originalFn = dyn_cast<FuncDecl>(original)) {
3898-
if (originalFn->hasDynamicSelfResult()) {
3899-
diags.diagnose(attr->getLocation(),
3900-
diag::differentiable_attr_class_member_no_dynamic_self);
3893+
auto *classDecl = original->getDeclContext()->getSelfClassDecl();
3894+
assert(classDecl);
3895+
// Class members returning dynamic `Self` are not supported.
3896+
// Dynamic `Self` is supported only as a single top-level result for class
3897+
// members. JVP/VJP functions returning `(Self, ...)` tuples would not
3898+
// type-check.
3899+
bool diagnoseDynamicSelfResult = original->hasDynamicSelfResult();
3900+
if (diagnoseDynamicSelfResult) {
3901+
// Diagnose class initializers in non-final classes.
3902+
if (isa<ConstructorDecl>(original)) {
3903+
if (!classDecl->isFinal()) {
3904+
diags.diagnose(
3905+
attr->getLocation(),
3906+
diag::differentiable_attr_nonfinal_class_init_unsupported,
3907+
classDecl->getDeclaredInterfaceType());
3908+
attr->setInvalid();
3909+
return nullptr;
3910+
}
3911+
}
3912+
// Diagnose all other declarations returning dynamic `Self`.
3913+
else {
3914+
diags.diagnose(
3915+
attr->getLocation(),
3916+
diag::
3917+
differentiable_attr_class_member_dynamic_self_result_unsupported);
39013918
attr->setInvalid();
39023919
return nullptr;
39033920
}
39043921
}
3905-
3906-
// TODO(TF-654): Class initializers are not yet supported.
3907-
// Extra JVP/VJP type calculation logic is necessary because classes have
3908-
// both allocators and initializers.
3909-
if (isa<ConstructorDecl>(original)) {
3910-
diags.diagnose(attr->getLocation(),
3911-
diag::differentiable_attr_class_init_not_yet_supported);
3912-
attr->setInvalid();
3913-
return nullptr;
3914-
}
39153922
}
39163923

39173924
// Resolve the derivative generic signature.
@@ -4201,6 +4208,38 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
42014208
return true;
42024209
}
42034210
}
4211+
// Diagnose if original function is an invalid class member.
4212+
bool isOriginalClassMember =
4213+
originalAFD->getDeclContext() &&
4214+
originalAFD->getDeclContext()->getSelfClassDecl();
4215+
if (isOriginalClassMember) {
4216+
auto *classDecl = originalAFD->getDeclContext()->getSelfClassDecl();
4217+
assert(classDecl);
4218+
// Class members returning dynamic `Self` are not supported.
4219+
// Dynamic `Self` is supported only as a single top-level result for class
4220+
// members. JVP/VJP functions returning `(Self, ...)` tuples would not
4221+
// type-check.
4222+
bool diagnoseDynamicSelfResult = originalAFD->hasDynamicSelfResult();
4223+
if (diagnoseDynamicSelfResult) {
4224+
// Diagnose class initializers in non-final classes.
4225+
if (isa<ConstructorDecl>(originalAFD)) {
4226+
if (!classDecl->isFinal()) {
4227+
diags.diagnose(attr->getLocation(),
4228+
diag::derivative_attr_nonfinal_class_init_unsupported,
4229+
classDecl->getDeclaredInterfaceType());
4230+
return true;
4231+
}
4232+
}
4233+
// Diagnose all other declarations returning dynamic `Self`.
4234+
else {
4235+
diags.diagnose(
4236+
attr->getLocation(),
4237+
diag::derivative_attr_class_member_dynamic_self_result_unsupported,
4238+
DeclNameRef(originalAFD->getFullName()));
4239+
return true;
4240+
}
4241+
}
4242+
}
42044243
attr->setOriginalFunction(originalAFD);
42054244

42064245
// Get the resolved differentiability parameter indices.

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,3 +598,29 @@ where Self: Differentiable, Self == Self.TangentVector {
598598
return (x + y, { v in (v, v) })
599599
}
600600
}
601+
602+
// Test derivatives of default implementations.
603+
protocol HasADefaultImplementation {
604+
func req(_ x: Float) -> Float
605+
}
606+
extension HasADefaultImplementation {
607+
func req(_ x: Float) -> Float { x }
608+
// ok
609+
@derivative(of: req)
610+
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
611+
(x, { 10 * $0 })
612+
}
613+
}
614+
615+
// Test default derivatives of requirements.
616+
protocol HasADefaultDerivative {
617+
func req(_ x: Float) -> Float
618+
}
619+
extension HasADefaultDerivative {
620+
// TODO(TF-982): Make this ok.
621+
// expected-error @+1 {{could not find function 'req'}}
622+
@derivative(of: req)
623+
func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
624+
(x, { 10 * $0 })
625+
}
626+
}

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,8 +1079,7 @@ class Super: Differentiable {
10791079

10801080
var base: Float
10811081

1082-
// NOTE(TF-654): Class initializers are not yet supported.
1083-
// expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}}
1082+
// expected-error @+1 {{'@differentiable' attribute cannot be declared on 'init' in a non-final class; consider making 'Super' final}}
10841083
@differentiable
10851084
init(base: Float) {
10861085
self.base = base
@@ -1123,7 +1122,7 @@ class Super: Differentiable {
11231122
func instanceMethod<T>(_ x: Float, y: T) -> Float { x }
11241123

11251124
// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
1126-
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class methods returning 'Self'}}
1125+
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class members returning 'Self'}}
11271126
@differentiable(vjp: vjpDynamicSelfResult)
11281127
func dynamicSelfResult() -> Self { self }
11291128

@@ -1147,6 +1146,18 @@ class Sub: Super {
11471146
override func testSuperclassDerivatives(_ x: Float) -> Float { x }
11481147
}
11491148

1149+
final class FinalClass: Differentiable {
1150+
typealias TangentVector = DummyTangentVector
1151+
func move(along _: TangentVector) {}
1152+
1153+
var base: Float
1154+
1155+
@differentiable
1156+
init(base: Float) {
1157+
self.base = base
1158+
}
1159+
}
1160+
11501161
// Test unsupported accessors: `set`, `_read`, `_modify`.
11511162

11521163
struct UnsupportedAccessors: Differentiable {

0 commit comments

Comments
 (0)