Skip to content

Commit 569b24c

Browse files
author
Marc Rasi
committed
typecheck all derivative attributes in the module
1 parent 8dfd83b commit 569b24c

File tree

9 files changed

+136
-90
lines changed

9 files changed

+136
-90
lines changed

include/swift/AST/ASTContext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ class ASTContext final {
311311
// same parameter indices but different derivative generic signatures.
312312
llvm::DenseMap<
313313
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
314-
DerivativeAttr *>
314+
llvm::SmallPtrSet<DerivativeAttr *, 1>>
315315
DerivativeAttrs;
316316
// SWIFT_ENABLE_TENSORFLOW END
317317

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -734,10 +734,6 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
734734
if (existingWitness)
735735
return existingWitness;
736736

737-
assert(original->isExternalDeclaration() &&
738-
"SILGen should create differentiability witnesses for all function "
739-
"definitions with explicit differentiable attributes");
740-
741737
return SILDifferentiabilityWitness::createDeclaration(
742738
module, SILLinkage::PublicExternal, original,
743739
minimalConfig->parameterIndices, minimalConfig->resultIndices,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 113 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "TypeCheckType.h"
2020
#include "TypeChecker.h"
2121
#include "swift/AST/ASTVisitor.h"
22+
#include "swift/AST/ASTWalker.h"
2223
#include "swift/AST/ClangModuleLoader.h"
2324
#include "swift/AST/DiagnosticsParse.h"
2425
#include "swift/AST/GenericEnvironment.h"
@@ -3592,7 +3593,24 @@ DifferentiableAttributeParameterIndicesRequest::evaluate(
35923593
}
35933594

35943595
// SWIFT_ENABLE_TENSORFLOW
3595-
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3596+
/// Typechecks the given derivative attribute `attr` on decl `D`.
3597+
///
3598+
/// Effects are:
3599+
/// - Sets the parameter indices on `attr`.
3600+
/// - Diagnoses errors.
3601+
/// - Stores the attribute in the `ASTContext` list of derivative attributes.
3602+
/// - Stores the derivative configuration in the original function's list of
3603+
/// derivative configurations.
3604+
///
3605+
/// \returns true on error, false on success.
3606+
static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
3607+
DerivativeAttr *attr) {
3608+
3609+
// Note: Implementation must be idempotent because it can get called multiple
3610+
// times for the same attribute.
3611+
3612+
auto &diags = Ctx.Diags;
3613+
35963614
FuncDecl *derivative = cast<FuncDecl>(D);
35973615
auto lookupConformance =
35983616
LookUpConformanceInModule(D->getDeclContext()->getParentModule());
@@ -3612,29 +3630,27 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36123630
auto derivativeResultTupleType = derivativeResultType->getAs<TupleType>();
36133631
if (!derivativeResultTupleType ||
36143632
derivativeResultTupleType->getNumElements() != 2) {
3615-
diagnose(attr->getLocation(), diag::derivative_attr_expected_result_tuple);
3616-
attr->setInvalid();
3617-
return;
3633+
diags.diagnose(attr->getLocation(),
3634+
diag::derivative_attr_expected_result_tuple);
3635+
return true;
36183636
}
36193637
auto valueResultElt = derivativeResultTupleType->getElement(0);
36203638
auto funcResultElt = derivativeResultTupleType->getElement(1);
36213639
// Get derivative kind and derivative function identifier.
36223640
AutoDiffDerivativeFunctionKind kind;
36233641
if (valueResultElt.getName().str() != "value") {
3624-
diagnose(attr->getLocation(),
3625-
diag::derivative_attr_invalid_result_tuple_value_label);
3626-
attr->setInvalid();
3627-
return;
3642+
diags.diagnose(attr->getLocation(),
3643+
diag::derivative_attr_invalid_result_tuple_value_label);
3644+
return true;
36283645
}
36293646
if (funcResultElt.getName().str() == "differential") {
36303647
kind = AutoDiffDerivativeFunctionKind::JVP;
36313648
} else if (funcResultElt.getName().str() == "pullback") {
36323649
kind = AutoDiffDerivativeFunctionKind::VJP;
36333650
} else {
3634-
diagnose(attr->getLocation(),
3635-
diag::derivative_attr_invalid_result_tuple_func_label);
3636-
attr->setInvalid();
3637-
return;
3651+
diags.diagnose(attr->getLocation(),
3652+
diag::derivative_attr_invalid_result_tuple_func_label);
3653+
return true;
36383654
}
36393655
attr->setDerivativeKind(kind);
36403656
// `value: R` result tuple element must conform to `Differentiable`.
@@ -3645,11 +3661,10 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36453661
auto valueResultConf = TypeChecker::conformsToProtocol(
36463662
valueResultType, diffableProto, derivative->getDeclContext(), None);
36473663
if (!valueResultConf) {
3648-
diagnose(attr->getLocation(),
3649-
diag::derivative_attr_result_value_not_differentiable,
3650-
valueResultElt.getType());
3651-
attr->setInvalid();
3652-
return;
3664+
diags.diagnose(attr->getLocation(),
3665+
diag::derivative_attr_result_value_not_differentiable,
3666+
valueResultElt.getType());
3667+
return true;
36533668
}
36543669

36553670
// Compute expected original function type and look up original function.
@@ -3693,23 +3708,23 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36933708
};
36943709

36953710
auto noneValidDiagnostic = [&]() {
3696-
diagnose(originalName.Loc,
3697-
diag::autodiff_attr_original_decl_none_valid_found,
3698-
originalName.Name, originalFnType);
3711+
diags.diagnose(originalName.Loc,
3712+
diag::autodiff_attr_original_decl_none_valid_found,
3713+
originalName.Name, originalFnType);
36993714
};
37003715
auto ambiguousDiagnostic = [&]() {
3701-
diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl,
3702-
originalName.Name, attr->getAttrName());
3716+
diags.diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl,
3717+
originalName.Name, attr->getAttrName());
37033718
};
37043719
auto notFunctionDiagnostic = [&]() {
3705-
diagnose(originalName.Loc,
3706-
diag::autodiff_attr_original_decl_invalid_kind,
3707-
originalName.Name);
3720+
diags.diagnose(originalName.Loc,
3721+
diag::autodiff_attr_original_decl_invalid_kind,
3722+
originalName.Name);
37083723
};
37093724
std::function<void()> invalidTypeContextDiagnostic = [&]() {
3710-
diagnose(originalName.Loc,
3711-
diag::autodiff_attr_original_decl_not_same_type_context,
3712-
originalName.Name);
3725+
diags.diagnose(originalName.Loc,
3726+
diag::autodiff_attr_original_decl_not_same_type_context,
3727+
originalName.Name);
37133728
};
37143729

37153730
// Returns true if the derivative function and original function candidate are
@@ -3743,52 +3758,39 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
37433758
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
37443759
hasValidTypeContext, invalidTypeContextDiagnostic);
37453760
if (!originalAFD) {
3746-
attr->setInvalid();
3747-
return;
3761+
return true;
37483762
}
37493763
// Diagnose original stored properties. Stored properties cannot have custom
37503764
// registered derivatives.
37513765
if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
37523766
auto *asd = accessorDecl->getStorage();
37533767
if (asd->hasStorage()) {
3754-
diagnose(originalName.Loc,
3755-
diag::derivative_attr_original_stored_property_unsupported,
3756-
originalName.Name);
3757-
diagnose(originalAFD->getLoc(), diag::decl_declared_here,
3758-
asd->getFullName());
3759-
attr->setInvalid();
3760-
return;
3768+
diags.diagnose(originalName.Loc,
3769+
diag::derivative_attr_original_stored_property_unsupported,
3770+
originalName.Name);
3771+
diags.diagnose(originalAFD->getLoc(), diag::decl_declared_here,
3772+
asd->getFullName());
3773+
return true;
37613774
}
37623775
}
37633776
attr->setOriginalFunction(originalAFD);
37643777

3765-
// Get checked wrt param indices.
3766-
auto *checkedWrtParamIndices = attr->getParameterIndices();
3767-
37683778
// Get the parsed wrt param indices, which have not yet been checked.
37693779
// This is defined for parsed attributes.
37703780
auto parsedWrtParams = attr->getParsedParameters();
37713781

3772-
// If checked wrt param indices are not specified, compute them.
3782+
auto *checkedWrtParamIndices = computeDifferentiationParameters(
3783+
parsedWrtParams, derivative, derivative->getGenericEnvironment(),
3784+
attr->getAttrName(), attr->getLocation());
37733785
if (!checkedWrtParamIndices)
3774-
checkedWrtParamIndices =
3775-
computeDifferentiationParameters(parsedWrtParams, derivative,
3776-
derivative->getGenericEnvironment(),
3777-
attr->getAttrName(),
3778-
attr->getLocation());
3779-
if (!checkedWrtParamIndices) {
3780-
attr->setInvalid();
3781-
return;
3782-
}
3786+
return true;
37833787

37843788
// Check if differentiation parameter indices are valid.
37853789
if (checkDifferentiationParameters(
37863790
originalAFD, checkedWrtParamIndices, originalFnType,
37873791
derivative->getGenericEnvironment(), derivative->getModuleContext(),
3788-
parsedWrtParams, attr->getLocation())) {
3789-
attr->setInvalid();
3790-
return;
3791-
}
3792+
parsedWrtParams, attr->getLocation()))
3793+
return true;
37923794

37933795
// Set the checked differentiation parameter indices in the attribute.
37943796
attr->setParameterIndices(checkedWrtParamIndices);
@@ -3846,25 +3848,25 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
38463848
// Check if differential/pullback type matches expected type.
38473849
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
38483850
// Emit differential/pullback type mismatch error on attribute.
3849-
diagnose(attr->getLocation(),
3850-
diag::derivative_attr_result_func_type_mismatch,
3851-
funcResultElt.getName(), originalAFD->getFullName());
3851+
diags.diagnose(attr->getLocation(),
3852+
diag::derivative_attr_result_func_type_mismatch,
3853+
funcResultElt.getName(), originalAFD->getFullName());
38523854
// Emit note with expected differential/pullback type on actual type
38533855
// location.
38543856
auto *tupleReturnTypeRepr =
38553857
cast<TupleTypeRepr>(derivative->getBodyResultTypeLoc().getTypeRepr());
38563858
auto *funcEltTypeRepr = tupleReturnTypeRepr->getElementType(1);
3857-
diagnose(funcEltTypeRepr->getStartLoc(),
3858-
diag::derivative_attr_result_func_type_mismatch_note,
3859-
funcResultElt.getName(), expectedFuncEltType)
3859+
diags
3860+
.diagnose(funcEltTypeRepr->getStartLoc(),
3861+
diag::derivative_attr_result_func_type_mismatch_note,
3862+
funcResultElt.getName(), expectedFuncEltType)
38603863
.highlight(funcEltTypeRepr->getSourceRange());
38613864
// Emit note showing original function location, if possible.
38623865
if (originalAFD->getLoc().isValid())
3863-
diagnose(originalAFD->getLoc(),
3864-
diag::derivative_attr_result_func_original_note,
3865-
originalAFD->getFullName());
3866-
attr->setInvalid();
3867-
return;
3866+
diags.diagnose(originalAFD->getLoc(),
3867+
diag::derivative_attr_result_func_original_note,
3868+
originalAFD->getFullName());
3869+
return true;
38683870
}
38693871

38703872
// Check that derivative visibility is at least as restricted as original
@@ -3873,29 +3875,42 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
38733875
originalAFD->getFormalAccessScope() &&
38743876
!derivative->getFormalAccessScope().isChildOf(
38753877
originalAFD->getFormalAccessScope())) {
3876-
diagnoseAndRemoveAttr(attr, diag::derivative_attr_visibility_too_broad);
3877-
diagnose(originalAFD->getLoc(),
3878-
diag::derivative_attr_visibility_too_broad_note);
3879-
return;
3878+
diags.diagnose(attr->getLocation(),
3879+
diag::derivative_attr_visibility_too_broad);
3880+
diags.diagnose(originalAFD->getLoc(),
3881+
diag::derivative_attr_visibility_too_broad_note);
3882+
return true;
38803883
}
38813884

38823885
// Reject duplicate `@derivative` attributes.
3883-
auto insertion = Ctx.DerivativeAttrs.try_emplace(
3884-
{originalAFD, checkedWrtParamIndices, kind}, attr);
3885-
if (!insertion.second) {
3886-
diagnoseAndRemoveAttr(attr,
3887-
diag::derivative_attr_original_already_has_derivative,
3888-
originalAFD->getFullName());
3889-
diagnose(insertion.first->getSecond()->getLocation(),
3890-
diag::differentiable_attr_duplicate_note);
3891-
return;
3886+
auto &derivativeAttrs =
3887+
Ctx.DerivativeAttrs[{originalAFD, checkedWrtParamIndices, kind}];
3888+
derivativeAttrs.insert(attr);
3889+
if (derivativeAttrs.size() > 1) {
3890+
diags.diagnose(attr->getLocation(),
3891+
diag::derivative_attr_original_already_has_derivative,
3892+
originalAFD->getFullName());
3893+
for (auto *duplicateAttr : derivativeAttrs) {
3894+
if (duplicateAttr == attr)
3895+
continue;
3896+
diags.diagnose(duplicateAttr->getLocation(),
3897+
diag::differentiable_attr_duplicate_note);
3898+
}
3899+
return true;
38923900
}
38933901

38943902
// Register derivative function configuration.
38953903
auto *resultIndices = IndexSubset::get(Ctx, 1, {0});
38963904
originalAFD->addDerivativeFunctionConfiguration(
38973905
{checkedWrtParamIndices, resultIndices,
38983906
derivative->getGenericSignature()});
3907+
3908+
return false;
3909+
}
3910+
3911+
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3912+
if (typeCheckDerivativeAttr(Ctx, D, attr))
3913+
attr->setInvalid();
38993914
}
39003915

39013916
void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
@@ -4527,3 +4542,23 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
45274542

45284543
return nullptr;
45294544
}
4545+
4546+
// SWIFT_ENABLE_TENSORFLOW
4547+
void TypeChecker::typeCheckDerivativeAttrs(SourceFile &sourceFile) {
4548+
class DerivativeAttrBindingWalker : public ASTWalker {
4549+
bool walkToDeclPre(Decl *decl) override {
4550+
auto f = dyn_cast<AbstractFunctionDecl>(decl);
4551+
if (!f)
4552+
return true;
4553+
for (auto *attr : f->getAttrs())
4554+
if (auto *da = dyn_cast<DerivativeAttr>(attr))
4555+
typeCheckDerivativeAttr(f->getASTContext(), f, da);
4556+
return true;
4557+
}
4558+
};
4559+
4560+
DiagnosticTransaction diagTxn(sourceFile.getASTContext().Diags);
4561+
DerivativeAttrBindingWalker walker;
4562+
sourceFile.getParentModule()->walk(walker);
4563+
diagTxn.abort();
4564+
}

lib/Sema/TypeChecker.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,16 @@ TypeCheckSourceFileRequest::evaluate(Evaluator &eval,
385385
// to work.
386386
::bindExtensions(*SF);
387387

388+
// SWIFT_ENABLE_TENSORFLOW
389+
// Type check all `@derivative` attributes in the module. Later, attribute
390+
// checking re-checks all `@derivative` attributes in the primary file(s).
391+
// This initial checking pass must occur before the re-checking, so that
392+
// re-checking can diagnose duplicate attributes using information about all
393+
// the derivatives in the module. The differentiation pass also relies on
394+
// the information collected by this initial checking pass to see
395+
// derivatives defined in the whole module.
396+
TypeChecker::typeCheckDerivativeAttrs(*SF);
397+
388398
// Type check the top-level elements of the source file.
389399
for (auto D : llvm::makeArrayRef(SF->Decls).slice(StartElem)) {
390400
if (auto *TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
@@ -405,9 +415,8 @@ TypeCheckSourceFileRequest::evaluate(Evaluator &eval,
405415
}
406416

407417
// Checking that benefits from having the whole module available.
408-
if (!Ctx.TypeCheckerOpts.DelayWholeModuleChecking) {
418+
if (!Ctx.TypeCheckerOpts.DelayWholeModuleChecking)
409419
performWholeModuleTypeChecking(*SF);
410-
}
411420

412421
return true;
413422
}

lib/Sema/TypeChecker.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,10 @@ class TypeChecker final {
15951595
// SWIFT_ENABLE_TENSORFLOW
15961596
void checkFunctionBodyCompilerEvaluable(AbstractFunctionDecl *D);
15971597

1598+
/// Typechecks all derivative attributes in the module that can affect
1599+
/// derivative configurations of functions in `sourceFile`.
1600+
static void typeCheckDerivativeAttrs(SourceFile &sourceFile);
1601+
15981602
/// If an expression references 'self.init' or 'super.init' in an
15991603
/// initializer context, returns the implicit 'self' decl of the constructor.
16001604
/// Otherwise, return nil.

lib/TBDGen/TBDGen.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ void TBDGenVisitor::addSymbol(StringRef name, SymbolKind kind) {
6767
if (StringSymbols && kind == SymbolKind::GlobalSymbol) {
6868
auto isNewValue = StringSymbols->insert(mangled).second;
6969
(void)isNewValue;
70-
if (!isNewValue)
71-
llvm::dbgs() << mangled << "\n";
7270
assert(isNewValue && "symbol appears twice");
7371
}
7472
}

test/AutoDiff/Inputs/derivative_attr_type_checking/main/main.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ func sin(_ x: Float) -> Float {
1111
func jvpSin(x: @nondiff Float) -> (value: Float, differential: (Float) -> (Float)) {
1212
return (x, { $0 })
1313
}
14+
// expected-error @+2 {{a derivative already exists for 'sin'}}
1415
// expected-note @+1 {{other attribute declared here}}
1516
@derivative(of: sin, wrt: x) // ok
1617
func vjpSinExplicitWrt(x: Float) -> (value: Float, pullback: (Float) -> Float) {
1718
return (x, { $0 })
1819
}
1920

20-
// expected-error @+1 {{a derivative already exists for 'sin'}}
21+
// expected-error @+2 {{a derivative already exists for 'sin'}}
22+
// expected-note @+1 {{other attribute declared here}}
2123
@derivative(of: sin)
2224
func vjpDuplicate(x: Float) -> (value: Float, pullback: (Float) -> Float) {
2325
return (x, { $0 })
@@ -554,6 +556,7 @@ func dDerivativesHaveDifferentAccessLevels2(_ x: Float) -> (value: Float, pullba
554556

555557
// Check that cross-file and cross-module duplicate derivatives are rejected.
556558

559+
// expected-error @+2 {{a derivative already exists for 'functionDefinedInOtherFile_publicDerivativeInOtherFile'}}
557560
// expected-note @+1 {{other attribute declared here}}
558561
@derivative(of: functionDefinedInOtherFile_publicDerivativeInOtherFile)
559562
public func crossFileDuplicateDerivative2(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {

0 commit comments

Comments
 (0)