Skip to content

Commit f81aba7

Browse files
authored
[AutoDiff][SR-13152] Better diagnostic for static decl modifier mismatch. (#36128)
Improved diagnostic for when the registered derivative function and the function it derivates (the original) differ in terms of `static` declaration modifier usage. Suggesting as well a fix-it, to either remove or add the `static` keyword. The registered derivative needs to be marked as `static` in two cases: 1. When the original function is a constructor. 2. When the original function is static as well. When the original function is an instance method, the registered derivative must be as well. Resolves [SR-13152](https://bugs.swift.org/browse/SR-13152).
1 parent 73ea1f1 commit f81aba7

File tree

4 files changed

+132
-22
lines changed

4 files changed

+132
-22
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3232,6 +3232,17 @@ NOTE(derivative_attr_fix_access,none,
32323232
"mark the derivative function as "
32333233
"'%select{private|fileprivate|internal|@usableFromInline|@usableFromInline}0' "
32343234
"to match the original function", (AccessLevel))
3235+
ERROR(derivative_attr_static_method_mismatch_original,none,
3236+
"unexpected derivative function declaration; "
3237+
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",
3238+
(/*original*/DeclName, /*derivative*/ DeclName,
3239+
/*originalIsStatic*/bool))
3240+
NOTE(derivative_attr_static_method_mismatch_original_note,none,
3241+
"original function %0 is %select{an instance|a 'static'}1 method",
3242+
(/*original*/ DeclName, /*originalIsStatic*/bool))
3243+
NOTE(derivative_attr_static_method_mismatch_fix,none,
3244+
"make derivative function %0 %select{an instance|a 'static'}1 method",
3245+
(/*derivative*/ DeclName, /*mustBeStatic*/bool))
32353246

32363247
// @transpose
32373248
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
@@ -3249,6 +3260,17 @@ ERROR(transpose_attr_wrt_self_must_be_static,none,
32493260
NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
32503261
"the transpose is declared in %0 but the original function is declared in "
32513262
"%1", (Type, Type))
3263+
ERROR(transpose_attr_static_method_mismatch_original,none,
3264+
"unexpected transpose function declaration; "
3265+
"%0 requires the transpose function %1 to be %select{an instance|a 'static'}2 method",
3266+
(/*original*/DeclName, /*transpose*/ DeclName,
3267+
/*originalIsStatic*/bool))
3268+
NOTE(transpose_attr_static_method_mismatch_original_note,none,
3269+
"original function %0 is %select{an instance|a 'static'}1 method",
3270+
(/*original*/ DeclName, /*originalIsStatic*/bool))
3271+
NOTE(transpose_attr_static_method_mismatch_fix,none,
3272+
"make transpose function %0 %select{an instance|a 'static'}1 method",
3273+
(/*transpose*/ DeclName, /*mustBeStatic*/bool))
32523274

32533275
// Automatic differentiation attributes
32543276
ERROR(autodiff_attr_original_decl_ambiguous,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4028,7 +4028,9 @@ static bool checkFunctionSignature(
40284028
if (!std::equal(required->getParams().begin(), required->getParams().end(),
40294029
candidateFnTy->getParams().begin(),
40304030
[&](AnyFunctionType::Param x, AnyFunctionType::Param y) {
4031-
return x.getOldType()->isEqual(mapType(y.getOldType()));
4031+
auto xInstanceTy = x.getOldType()->getMetatypeInstanceType();
4032+
auto yInstanceTy = y.getOldType()->getMetatypeInstanceType();
4033+
return xInstanceTy->isEqual(mapType(yInstanceTy));
40324034
}))
40334035
return false;
40344036

@@ -4827,8 +4829,41 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
48274829
}
48284830
}
48294831
}
4832+
48304833
attr->setOriginalFunction(originalAFD);
48314834

4835+
// Returns true if:
4836+
// - Original function and derivative function are static methods.
4837+
// - Original function and derivative function are non-static methods.
4838+
// - Original function is a Constructor declaration and derivative function is
4839+
// a static method.
4840+
auto compatibleStaticDecls = [&]() {
4841+
return (isa<ConstructorDecl>(originalAFD) || originalAFD->isStatic()) ==
4842+
derivative->isStatic();
4843+
};
4844+
4845+
// Diagnose if original function and derivative differ in terms of static declaration.
4846+
if (!compatibleStaticDecls()) {
4847+
bool derivativeMustBeStatic = !derivative->isStatic();
4848+
diags.diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(),
4849+
diag::derivative_attr_static_method_mismatch_original,
4850+
originalAFD->getName(), derivative->getName(),
4851+
derivativeMustBeStatic);
4852+
diags.diagnose(originalAFD->getNameLoc(),
4853+
diag::derivative_attr_static_method_mismatch_original_note,
4854+
originalAFD->getName(), derivativeMustBeStatic);
4855+
auto fixItDiag =
4856+
diags.diagnose(derivative->getStartLoc(),
4857+
diag::derivative_attr_static_method_mismatch_fix,
4858+
derivative->getName(), derivativeMustBeStatic);
4859+
if (derivativeMustBeStatic) {
4860+
fixItDiag.fixItInsert(derivative->getStartLoc(), "static ");
4861+
} else {
4862+
fixItDiag.fixItRemove(derivative->getStaticLoc());
4863+
}
4864+
return true;
4865+
}
4866+
48324867
// Returns true if:
48334868
// - Original function and derivative function have the same access level.
48344869
// - Original function is public and derivative function is internal
@@ -5231,8 +5266,9 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
52315266
// If the transpose function is curried and `self` is a linearity parameter,
52325267
// check that the instance and static `Self` types are equal.
52335268
Type staticSelfType, instanceSelfType;
5269+
bool doSelfTypesMatch = false;
52345270
if (isCurried && wrtSelf) {
5235-
bool doSelfTypesMatch = doTransposeStaticAndInstanceSelfTypesMatch(
5271+
doSelfTypesMatch = doTransposeStaticAndInstanceSelfTypesMatch(
52365272
transposeInterfaceType, staticSelfType, instanceSelfType);
52375273
if (!doSelfTypesMatch) {
52385274
diagnose(attr->getLocation(),
@@ -5343,6 +5379,37 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
53435379
return;
53445380
}
53455381

5382+
// Returns true if:
5383+
// - Original function and transpose function are static methods.
5384+
// - Original function and transpose function are non-static methods.
5385+
// - Original function is a Constructor declaration and transpose function is
5386+
// a static method.
5387+
auto compatibleStaticDecls = [&]() {
5388+
return (isa<ConstructorDecl>(originalAFD) || originalAFD->isStatic()) ==
5389+
transpose->isStatic();
5390+
};
5391+
5392+
// Diagnose if original function and transpose differ in terms of static declaration.
5393+
if (!doSelfTypesMatch && !compatibleStaticDecls()) {
5394+
bool transposeMustBeStatic = !transpose->isStatic();
5395+
diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(),
5396+
diag::transpose_attr_static_method_mismatch_original,
5397+
originalAFD->getName(), transpose->getName(),
5398+
transposeMustBeStatic);
5399+
diagnose(originalAFD->getNameLoc(),
5400+
diag::transpose_attr_static_method_mismatch_original_note,
5401+
originalAFD->getName(), transposeMustBeStatic);
5402+
auto fixItDiag = diagnose(transpose->getStartLoc(),
5403+
diag::transpose_attr_static_method_mismatch_fix,
5404+
transpose->getName(), transposeMustBeStatic);
5405+
if (transposeMustBeStatic) {
5406+
fixItDiag.fixItInsert(transpose->getStartLoc(), "static ");
5407+
} else {
5408+
fixItDiag.fixItRemove(transpose->getStaticLoc());
5409+
}
5410+
return;
5411+
}
5412+
53465413
// Set the resolved linearity parameter indices in the attribute.
53475414
attr->setParameterIndices(linearParamIndices);
53485415
}

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -669,21 +669,6 @@ func jvpInvalid<T: Differentiable>(x: T) -> (
669669
return (x, { $0 })
670670
}
671671

672-
// Test invalid derivative type context: instance vs static method mismatch.
673-
674-
struct InvalidTypeContext<T: Differentiable> {
675-
// expected-note @+1 {{candidate static method does not have type equal to or less constrained than '<T where T : Differentiable> (InvalidTypeContext<T>) -> (T) -> T'}}
676-
static func staticMethod(_ x: T) -> T { x }
677-
678-
// expected-error @+1 {{referenced declaration 'staticMethod' could not be resolved}}
679-
@derivative(of: staticMethod)
680-
func jvpStatic(_ x: T) -> (
681-
value: T, differential: (T.TangentVector) -> (T.TangentVector)
682-
) {
683-
return (x, { $0 })
684-
}
685-
}
686-
687672
// Test stored property original declaration.
688673

689674
struct HasStoredProperty {
@@ -1165,3 +1150,41 @@ func opaqueResult(_ x: Float) -> some Differentiable { x }
11651150
func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
11661151
fatalError()
11671152
}
1153+
1154+
// Test instance vs static method mismatch.
1155+
1156+
struct StaticMismatch<T: Differentiable> {
1157+
// expected-note @+1 {{original function 'init(_:)' is a 'static' method}}
1158+
init(_ x: T) {}
1159+
// expected-note @+1 {{original function 'instanceMethod' is an instance method}}
1160+
func instanceMethod(_ x: T) -> T { x }
1161+
// expected-note @+1 {{original function 'staticMethod' is a 'static' method}}
1162+
static func staticMethod(_ x: T) -> T { x }
1163+
1164+
// expected-error @+1 {{unexpected derivative function declaration; 'init(_:)' requires the derivative function 'vjpInit' to be a 'static' method}}
1165+
@derivative(of: init)
1166+
// expected-note @+1 {{make derivative function 'vjpInit' a 'static' method}}{{3-3=static }}
1167+
func vjpInit(_ x: T) -> (value: Self, pullback: (T.TangentVector) -> T.TangentVector) {
1168+
fatalError()
1169+
}
1170+
1171+
// expected-error @+1 {{unexpected derivative function declaration; 'instanceMethod' requires the derivative function 'jvpInstance' to be an instance method}}
1172+
@derivative(of: instanceMethod)
1173+
// expected-note @+1 {{make derivative function 'jvpInstance' an instance method}}{{3-10=}}
1174+
static func jvpInstance(_ x: T) -> (
1175+
value: T, differential: (T.TangentVector) -> (T.TangentVector)
1176+
) {
1177+
return (x, { $0 })
1178+
}
1179+
1180+
// expected-error @+1 {{unexpected derivative function declaration; 'staticMethod' requires the derivative function 'jvpStatic' to be a 'static' method}}
1181+
@derivative(of: staticMethod)
1182+
// expected-note @+1 {{make derivative function 'jvpStatic' a 'static' method}}{{3-3=static }}
1183+
func jvpStatic(_ x: T) -> (
1184+
value: T, differential: (T.TangentVector) -> (T.TangentVector)
1185+
) {
1186+
return (x, { $0 })
1187+
}
1188+
}
1189+
1190+

test/AutoDiff/Sema/transpose_attr_type_checking.swift

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,6 @@ extension Float {
474474

475475
// Test non-`func` original declarations.
476476

477-
// expected-note @+1 {{candidate initializer does not have type equal to or less constrained than '<T where T : Differentiable, T == T.TangentVector> (Struct<T>) -> (Float) -> Struct<T>'}}
478477
struct Struct<T> {}
479478
extension Struct: Equatable where T: Equatable {}
480479
extension Struct: Differentiable & AdditiveArithmetic
@@ -499,9 +498,8 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
499498

500499
// Test initializers.
501500
extension Struct {
502-
// expected-note @+1 {{candidate initializer does not have type equal to or less constrained than '<T where T : Differentiable, T == T.TangentVector> (Struct<T>) -> (Float) -> Struct<T>'}}
501+
// expected-note @+1 {{original function 'init(_:)' is a 'static' method}}
503502
init(_ x: Float) {}
504-
// expected-note @+1 {{candidate initializer does not have type equal to or less constrained than '<T where T : Differentiable, T == T.TangentVector> (Struct<T>) -> (Float) -> Struct<T>'}}
505503
init(_ x: T, y: Float) {}
506504
}
507505

@@ -517,9 +515,9 @@ extension Struct where T: Differentiable, T == T.TangentVector {
517515
}
518516

519517
// Test instance transpose for static original initializer.
520-
// TODO(TF-1015): Add improved instance/static member mismatch error.
521-
// expected-error @+1 {{referenced declaration 'init' could not be resolved}}
518+
// expected-error @+1 {{unexpected transpose function declaration; 'init(_:)' requires the transpose function 'vjpInitStaticMismatch' to be a 'static' method}}
522519
@transpose(of: init, wrt: 0)
520+
// expected-note @+1 {{make transpose function 'vjpInitStaticMismatch' a 'static' method}}{{3-3=static }}
523521
func vjpInitStaticMismatch(_ x: Self) -> Float {
524522
fatalError()
525523
}

0 commit comments

Comments
 (0)