Skip to content

Commit 56f09db

Browse files
committed
addressing feedback on PR
* splitting diagnostic messages for `@derivative` and `@transpose` attr. * duplicating the diagnostic logic using lambda functions within the respective attr checking methods. * change wording of error/notes to use non-`static`|`static`.
1 parent df75fa4 commit 56f09db

File tree

4 files changed

+94
-64
lines changed

4 files changed

+94
-64
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3216,6 +3216,17 @@ NOTE(derivative_attr_fix_access,none,
32163216
"mark the derivative function as "
32173217
"'%select{private|fileprivate|internal|@usableFromInline|@usableFromInline}0' "
32183218
"to match the original function", (AccessLevel))
3219+
ERROR(derivative_attr_static_method_mismatch_original,none,
3220+
"unexpected derivative function declaration; "
3221+
"%0 requires the derivative function %1 to be %select{an instance|a 'static'}2 method",
3222+
(/*original*/DeclName, /*derivative*/ DeclName,
3223+
/*originalIsStatic*/bool))
3224+
NOTE(derivative_attr_static_method_mismatch_original_note,none,
3225+
"original function %0 is %select{an instance|a 'static'}1 method",
3226+
(/*original*/ DeclName, /*originalIsStatic*/bool))
3227+
NOTE(derivative_attr_static_method_mismatch_fix,none,
3228+
"make derivative function %0 %select{an instance|a 'static'}1 method",
3229+
(/*derivative*/ DeclName, /*mustBeStatic*/bool))
32193230

32203231
// @transpose
32213232
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,
@@ -3233,6 +3244,17 @@ ERROR(transpose_attr_wrt_self_must_be_static,none,
32333244
NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
32343245
"the transpose is declared in %0 but the original function is declared in "
32353246
"%1", (Type, Type))
3247+
ERROR(transpose_attr_static_method_mismatch_original,none,
3248+
"unexpected transpose function declaration; "
3249+
"%0 requires the transpose function %1 to be %select{an instance|a 'static'}2 method",
3250+
(/*original*/DeclName, /*transpose*/ DeclName,
3251+
/*originalIsStatic*/bool))
3252+
NOTE(transpose_attr_static_method_mismatch_original_note,none,
3253+
"original function %0 is %select{an instance|a 'static'}1 method",
3254+
(/*original*/ DeclName, /*originalIsStatic*/bool))
3255+
NOTE(transpose_attr_static_method_mismatch_fix,none,
3256+
"make transpose function %0 %select{an instance|a 'static'}1 method",
3257+
(/*transpose*/ DeclName, /*mustBeStatic*/bool))
32363258

32373259
// Automatic differentiation attributes
32383260
ERROR(autodiff_attr_original_decl_ambiguous,none,
@@ -3264,18 +3286,6 @@ ERROR(autodiff_attr_result_not_differentiable,none,
32643286
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
32653287
ERROR(autodiff_attr_opaque_result_type_unsupported,none,
32663288
"cannot differentiate functions returning opaque result types", ())
3267-
ERROR(autodiff_attr_static_decl_mismatch,none,
3268-
"derivative function %0 operates on %select{an instance|a}1 type, "
3269-
"not on %select{an instance|a}2 type as required",
3270-
(/*derivative*/ DeclName, /*derivativeOperatesOnTy*/bool,
3271-
/*originalOperatesOnTy*/bool))
3272-
NOTE(autodiff_attr_static_decl_original,none,
3273-
"original function %0 operates on %select{an instance|a}1 type",
3274-
(/*original*/ DeclName, /*originalOperatesOnTy*/bool))
3275-
NOTE(autodiff_attr_static_decl_mismatch_fix,none,
3276-
"derivative function %0 must %select{not be|be}1 'static'",
3277-
(/*derivative*/ DeclName, /*mustBeStatic*/bool))
3278-
32793289

32803290
// differentiation `wrt` parameters clause
32813291
ERROR(diff_function_no_parameters,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
#include "swift/AST/PropertyWrappers.h"
3535
#include "swift/AST/SourceFile.h"
3636
#include "swift/AST/StorageImpl.h"
37+
#include "swift/AST/TypeAlignments.h"
3738
#include "swift/AST/TypeCheckRequests.h"
3839
#include "swift/AST/Types.h"
40+
#include "swift/Basic/SourceLoc.h"
3941
#include "swift/Parse/Lexer.h"
4042
#include "swift/Sema/IDETypeChecking.h"
4143
#include "clang/Basic/CharInfo.h"
@@ -4583,39 +4585,6 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
45834585
(void)attr->getParameterIndices();
45844586
}
45854587

4586-
/// Checks if original candidate and registered derivative match in terms
4587-
/// of static declaration. If the original candidate is a constructor or is
4588-
/// defined as static method then the registered derivative is expected to be static.
4589-
/// Otherwise the registered derivative should be an instance method.
4590-
/// \returns true if mismatch is found, false otherwise.
4591-
static bool checkStaticDeclMismatch(AbstractFunctionDecl *originalCandidate,
4592-
AbstractFunctionDecl *registered) {
4593-
return (isa<ConstructorDecl>(originalCandidate) ||
4594-
originalCandidate->isStatic()) != registered->isStatic();
4595-
}
4596-
4597-
/// Produces diagnostics for mismatch in static/instance method declaration
4598-
/// between original candidate and registered derivative.
4599-
static void diagnoseStaticDeclMismatch(AbstractFunctionDecl *originalCandidate,
4600-
FuncDecl *registered) {
4601-
auto &diags = originalCandidate->getASTContext().Diags;
4602-
diags.diagnose(
4603-
registered->getNameLoc(), diag::autodiff_attr_static_decl_mismatch,
4604-
registered->getName(), registered->isStatic(), !registered->isStatic());
4605-
diags.diagnose(
4606-
originalCandidate->getNameLoc(), diag::autodiff_attr_static_decl_original,
4607-
originalCandidate->getName(),
4608-
isa<ConstructorDecl>(originalCandidate) || originalCandidate->isStatic());
4609-
auto fixItDiag = diags.diagnose(
4610-
registered->getStartLoc(), diag::autodiff_attr_static_decl_mismatch_fix,
4611-
registered->getName(), !registered->isStatic());
4612-
if (registered->isStatic()) {
4613-
fixItDiag.fixItRemove(registered->getStaticLoc());
4614-
} else {
4615-
fixItDiag.fixItInsert(registered->getStartLoc(), "static ");
4616-
}
4617-
}
4618-
46194588
/// Type-checks the given `@derivative` attribute `attr` on declaration `D`.
46204589
///
46214590
/// Effects are:
@@ -4842,12 +4811,39 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
48424811
}
48434812
}
48444813

4814+
attr->setOriginalFunction(originalAFD);
4815+
4816+
// Returns true if:
4817+
// - Original function and derivative function are static methods.
4818+
// - Original function and derivative function are non-static methods.
4819+
// - Original function is a Constructor declaration and derivative function is
4820+
// a static method.
4821+
auto compatibleStaticDecls = [&]() {
4822+
return (isa<ConstructorDecl>(originalAFD) || originalAFD->isStatic()) ==
4823+
derivative->isStatic();
4824+
};
4825+
48454826
// Diagnose if original function and derivative differ in terms of static declaration.
4846-
if (checkStaticDeclMismatch(originalAFD, derivative)) {
4847-
diagnoseStaticDeclMismatch(originalAFD, derivative);
4827+
if (!compatibleStaticDecls()) {
4828+
bool derivativeMustBeStatic = !derivative->isStatic();
4829+
diags.diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(),
4830+
diag::derivative_attr_static_method_mismatch_original,
4831+
originalAFD->getName(), derivative->getName(),
4832+
derivativeMustBeStatic);
4833+
diags.diagnose(originalAFD->getNameLoc(),
4834+
diag::derivative_attr_static_method_mismatch_original_note,
4835+
originalAFD->getName(), derivativeMustBeStatic);
4836+
auto fixItDiag =
4837+
diags.diagnose(derivative->getStartLoc(),
4838+
diag::derivative_attr_static_method_mismatch_fix,
4839+
derivative->getName(), derivativeMustBeStatic);
4840+
if (derivativeMustBeStatic) {
4841+
fixItDiag.fixItInsert(derivative->getStartLoc(), "static ");
4842+
} else {
4843+
fixItDiag.fixItRemove(derivative->getStaticLoc());
4844+
}
48484845
return true;
48494846
}
4850-
attr->setOriginalFunction(originalAFD);
48514847

48524848
// Returns true if:
48534849
// - Original function and derivative function have the same access level.
@@ -5364,11 +5360,35 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
53645360
return;
53655361
}
53665362

5363+
// Returns true if:
5364+
// - Original function and transpose function are static methods.
5365+
// - Original function and transpose function are non-static methods.
5366+
// - Original function is a Constructor declaration and transpose function is
5367+
// a static method.
5368+
auto compatibleStaticDecls = [&]() {
5369+
return (isa<ConstructorDecl>(originalAFD) || originalAFD->isStatic()) ==
5370+
transpose->isStatic();
5371+
};
5372+
53675373
// Diagnose if original function and transpose differ in terms of static declaration.
5368-
if (!doSelfTypesMatch && checkStaticDeclMismatch(originalAFD, transpose)) {
5369-
diagnoseStaticDeclMismatch(originalAFD, transpose);
5370-
attr->setInvalid();
5371-
return;
5374+
if (!doSelfTypesMatch && !compatibleStaticDecls()) {
5375+
bool transposeMustBeStatic = !transpose->isStatic();
5376+
diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(),
5377+
diag::transpose_attr_static_method_mismatch_original,
5378+
originalAFD->getName(), transpose->getName(),
5379+
transposeMustBeStatic);
5380+
diagnose(originalAFD->getNameLoc(),
5381+
diag::transpose_attr_static_method_mismatch_original_note,
5382+
originalAFD->getName(), transposeMustBeStatic);
5383+
auto fixItDiag = diagnose(transpose->getStartLoc(),
5384+
diag::transpose_attr_static_method_mismatch_fix,
5385+
transpose->getName(), transposeMustBeStatic);
5386+
if (transposeMustBeStatic) {
5387+
fixItDiag.fixItInsert(transpose->getStartLoc(), "static ");
5388+
} else {
5389+
fixItDiag.fixItRemove(transpose->getStaticLoc());
5390+
}
5391+
return;
53725392
}
53735393

53745394
// Set the resolved linearity parameter indices in the attribute.

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,32 +1154,32 @@ func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
11541154
// Test instance vs static method mismatch.
11551155

11561156
struct StaticMismatch<T: Differentiable> {
1157-
// expected-note @+1 {{original function 'init(_:)' operates on a type}}
1157+
// expected-note @+1 {{original function 'init(_:)' is a 'static' method}}
11581158
init(_ x: T) {}
1159-
// expected-note @+1 {{original function 'instanceMethod' operates on an instance type}}
1159+
// expected-note @+1 {{original function 'instanceMethod' is an instance method}}
11601160
func instanceMethod(_ x: T) -> T { x }
1161-
// expected-note @+1 {{original function 'staticMethod' operates on a type}}
1161+
// expected-note @+1 {{original function 'staticMethod' is a 'static' method}}
11621162
static func staticMethod(_ x: T) -> T { x }
11631163

1164+
// expected-error @+1 {{unexpected derivative function declaration; 'init(_:)' requires the derivative function 'vjpInit' to be a 'static' method}}
11641165
@derivative(of: init)
1165-
// expected-error @+2 {{derivative function 'vjpInit' operates on an instance type, not on a type as required}}
1166-
// expected-note @+1 {{derivative function 'vjpInit' must be 'static'}}{{3-3=static }}
1166+
// expected-note @+1 {{make derivative function 'vjpInit' a 'static' method}}{{3-3=static }}
11671167
func vjpInit(_ x: T) -> (value: Self, pullback: (T.TangentVector) -> T.TangentVector) {
11681168
fatalError()
11691169
}
11701170

1171+
// expected-error @+1 {{unexpected derivative function declaration; 'instanceMethod' requires the derivative function 'jvpInstance' to be an instance method}}
11711172
@derivative(of: instanceMethod)
1172-
// expected-error @+2 {{derivative function 'jvpInstance' operates on a type, not on an instance type as required}}
1173-
// expected-note @+1 {{derivative function 'jvpInstance' must not be 'static'}}{{3-10=}}
1173+
// expected-note @+1 {{make derivative function 'jvpInstance' an instance method}}{{3-10=}}
11741174
static func jvpInstance(_ x: T) -> (
11751175
value: T, differential: (T.TangentVector) -> (T.TangentVector)
11761176
) {
11771177
return (x, { $0 })
11781178
}
11791179

1180+
// expected-error @+1 {{unexpected derivative function declaration; 'staticMethod' requires the derivative function 'jvpStatic' to be a 'static' method}}
11801181
@derivative(of: staticMethod)
1181-
// expected-error @+2 {{derivative function 'jvpStatic' operates on an instance type, not on a type as required}}
1182-
// expected-note @+1 {{derivative function 'jvpStatic' must be 'static'}}{{3-3=static }}
1182+
// expected-note @+1 {{make derivative function 'jvpStatic' a 'static' method}}{{3-3=static }}
11831183
func jvpStatic(_ x: T) -> (
11841184
value: T, differential: (T.TangentVector) -> (T.TangentVector)
11851185
) {

test/AutoDiff/Sema/transpose_attr_type_checking.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic {
498498

499499
// Test initializers.
500500
extension Struct {
501-
// expected-note @+1 {{original function 'init(_:)' operates on a type}}
501+
// expected-note @+1 {{original function 'init(_:)' is a 'static' method}}
502502
init(_ x: Float) {}
503503
init(_ x: T, y: Float) {}
504504
}
@@ -515,9 +515,9 @@ extension Struct where T: Differentiable, T == T.TangentVector {
515515
}
516516

517517
// Test instance transpose for static original initializer.
518+
// expected-error @+1 {{unexpected transpose function declaration; 'init(_:)' requires the transpose function 'vjpInitStaticMismatch' to be a 'static' method}}
518519
@transpose(of: init, wrt: 0)
519-
// expected-error @+2 {{derivative function 'vjpInitStaticMismatch' operates on an instance type, not on a type as required}}
520-
// expected-note @+1 {{derivative function 'vjpInitStaticMismatch' must be 'static'}}{{3-3=static }}
520+
// expected-note @+1 {{make transpose function 'vjpInitStaticMismatch' a 'static' method}}{{3-3=static }}
521521
func vjpInitStaticMismatch(_ x: Self) -> Float {
522522
fatalError()
523523
}

0 commit comments

Comments
 (0)