|
34 | 34 | #include "swift/AST/PropertyWrappers.h"
|
35 | 35 | #include "swift/AST/SourceFile.h"
|
36 | 36 | #include "swift/AST/StorageImpl.h"
|
| 37 | +#include "swift/AST/TypeAlignments.h" |
37 | 38 | #include "swift/AST/TypeCheckRequests.h"
|
38 | 39 | #include "swift/AST/Types.h"
|
| 40 | +#include "swift/Basic/SourceLoc.h" |
39 | 41 | #include "swift/Parse/Lexer.h"
|
40 | 42 | #include "swift/Sema/IDETypeChecking.h"
|
41 | 43 | #include "clang/Basic/CharInfo.h"
|
@@ -4583,39 +4585,6 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
|
4583 | 4585 | (void)attr->getParameterIndices();
|
4584 | 4586 | }
|
4585 | 4587 |
|
4586 |
| -/// Checks if original candidate and registered derivative match in terms |
4587 |
| -/// of static declaration. If the original candidate is a constructor or is |
4588 |
| -/// defined as static method then the registered derivative is expected to be static. |
4589 |
| -/// Otherwise the registered derivative should be an instance method. |
4590 |
| -/// \returns true if mismatch is found, false otherwise. |
4591 |
| -static bool checkStaticDeclMismatch(AbstractFunctionDecl *originalCandidate, |
4592 |
| - AbstractFunctionDecl *registered) { |
4593 |
| - return (isa<ConstructorDecl>(originalCandidate) || |
4594 |
| - originalCandidate->isStatic()) != registered->isStatic(); |
4595 |
| -} |
4596 |
| - |
4597 |
| -/// Produces diagnostics for mismatch in static/instance method declaration |
4598 |
| -/// between original candidate and registered derivative. |
4599 |
| -static void diagnoseStaticDeclMismatch(AbstractFunctionDecl *originalCandidate, |
4600 |
| - FuncDecl *registered) { |
4601 |
| - auto &diags = originalCandidate->getASTContext().Diags; |
4602 |
| - diags.diagnose( |
4603 |
| - registered->getNameLoc(), diag::autodiff_attr_static_decl_mismatch, |
4604 |
| - registered->getName(), registered->isStatic(), !registered->isStatic()); |
4605 |
| - diags.diagnose( |
4606 |
| - originalCandidate->getNameLoc(), diag::autodiff_attr_static_decl_original, |
4607 |
| - originalCandidate->getName(), |
4608 |
| - isa<ConstructorDecl>(originalCandidate) || originalCandidate->isStatic()); |
4609 |
| - auto fixItDiag = diags.diagnose( |
4610 |
| - registered->getStartLoc(), diag::autodiff_attr_static_decl_mismatch_fix, |
4611 |
| - registered->getName(), !registered->isStatic()); |
4612 |
| - if (registered->isStatic()) { |
4613 |
| - fixItDiag.fixItRemove(registered->getStaticLoc()); |
4614 |
| - } else { |
4615 |
| - fixItDiag.fixItInsert(registered->getStartLoc(), "static "); |
4616 |
| - } |
4617 |
| -} |
4618 |
| - |
4619 | 4588 | /// Type-checks the given `@derivative` attribute `attr` on declaration `D`.
|
4620 | 4589 | ///
|
4621 | 4590 | /// Effects are:
|
@@ -4842,12 +4811,39 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
|
4842 | 4811 | }
|
4843 | 4812 | }
|
4844 | 4813 |
|
| 4814 | + attr->setOriginalFunction(originalAFD); |
| 4815 | + |
| 4816 | + // Returns true if: |
| 4817 | + // - Original function and derivative function are static methods. |
| 4818 | + // - Original function and derivative function are non-static methods. |
| 4819 | + // - Original function is a Constructor declaration and derivative function is |
| 4820 | + // a static method. |
| 4821 | + auto compatibleStaticDecls = [&]() { |
| 4822 | + return (isa<ConstructorDecl>(originalAFD) || originalAFD->isStatic()) == |
| 4823 | + derivative->isStatic(); |
| 4824 | + }; |
| 4825 | + |
4845 | 4826 | // Diagnose if original function and derivative differ in terms of static declaration.
|
4846 |
| - if (checkStaticDeclMismatch(originalAFD, derivative)) { |
4847 |
| - diagnoseStaticDeclMismatch(originalAFD, derivative); |
| 4827 | + if (!compatibleStaticDecls()) { |
| 4828 | + bool derivativeMustBeStatic = !derivative->isStatic(); |
| 4829 | + diags.diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(), |
| 4830 | + diag::derivative_attr_static_method_mismatch_original, |
| 4831 | + originalAFD->getName(), derivative->getName(), |
| 4832 | + derivativeMustBeStatic); |
| 4833 | + diags.diagnose(originalAFD->getNameLoc(), |
| 4834 | + diag::derivative_attr_static_method_mismatch_original_note, |
| 4835 | + originalAFD->getName(), derivativeMustBeStatic); |
| 4836 | + auto fixItDiag = |
| 4837 | + diags.diagnose(derivative->getStartLoc(), |
| 4838 | + diag::derivative_attr_static_method_mismatch_fix, |
| 4839 | + derivative->getName(), derivativeMustBeStatic); |
| 4840 | + if (derivativeMustBeStatic) { |
| 4841 | + fixItDiag.fixItInsert(derivative->getStartLoc(), "static "); |
| 4842 | + } else { |
| 4843 | + fixItDiag.fixItRemove(derivative->getStaticLoc()); |
| 4844 | + } |
4848 | 4845 | return true;
|
4849 | 4846 | }
|
4850 |
| - attr->setOriginalFunction(originalAFD); |
4851 | 4847 |
|
4852 | 4848 | // Returns true if:
|
4853 | 4849 | // - Original function and derivative function have the same access level.
|
@@ -5364,11 +5360,35 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
|
5364 | 5360 | return;
|
5365 | 5361 | }
|
5366 | 5362 |
|
| 5363 | + // Returns true if: |
| 5364 | + // - Original function and transpose function are static methods. |
| 5365 | + // - Original function and transpose function are non-static methods. |
| 5366 | + // - Original function is a Constructor declaration and transpose function is |
| 5367 | + // a static method. |
| 5368 | + auto compatibleStaticDecls = [&]() { |
| 5369 | + return (isa<ConstructorDecl>(originalAFD) || originalAFD->isStatic()) == |
| 5370 | + transpose->isStatic(); |
| 5371 | + }; |
| 5372 | + |
5367 | 5373 | // Diagnose if original function and transpose differ in terms of static declaration.
|
5368 |
| - if (!doSelfTypesMatch && checkStaticDeclMismatch(originalAFD, transpose)) { |
5369 |
| - diagnoseStaticDeclMismatch(originalAFD, transpose); |
5370 |
| - attr->setInvalid(); |
5371 |
| - return; |
| 5374 | + if (!doSelfTypesMatch && !compatibleStaticDecls()) { |
| 5375 | + bool transposeMustBeStatic = !transpose->isStatic(); |
| 5376 | + diagnose(attr->getOriginalFunctionName().Loc.getBaseNameLoc(), |
| 5377 | + diag::transpose_attr_static_method_mismatch_original, |
| 5378 | + originalAFD->getName(), transpose->getName(), |
| 5379 | + transposeMustBeStatic); |
| 5380 | + diagnose(originalAFD->getNameLoc(), |
| 5381 | + diag::transpose_attr_static_method_mismatch_original_note, |
| 5382 | + originalAFD->getName(), transposeMustBeStatic); |
| 5383 | + auto fixItDiag = diagnose(transpose->getStartLoc(), |
| 5384 | + diag::transpose_attr_static_method_mismatch_fix, |
| 5385 | + transpose->getName(), transposeMustBeStatic); |
| 5386 | + if (transposeMustBeStatic) { |
| 5387 | + fixItDiag.fixItInsert(transpose->getStartLoc(), "static "); |
| 5388 | + } else { |
| 5389 | + fixItDiag.fixItRemove(transpose->getStaticLoc()); |
| 5390 | + } |
| 5391 | + return; |
5372 | 5392 | }
|
5373 | 5393 |
|
5374 | 5394 | // Set the resolved linearity parameter indices in the attribute.
|
|
0 commit comments