Skip to content

Commit 637314a

Browse files
author
Marc Rasi
committed
[AutoDiff] factor derivative typechecking helper out of AttributeChecker
1 parent a50b940 commit 637314a

File tree

2 files changed

+90
-67
lines changed

2 files changed

+90
-67
lines changed

include/swift/AST/ASTContext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class ASTContext final {
290290
// derivative function configurations per original `AbstractFunctionDecl`.
291291
llvm::DenseMap<
292292
std::tuple<Decl *, IndexSubset *, AutoDiffDerivativeFunctionKind>,
293-
DerivativeAttr *>
293+
llvm::SmallPtrSet<DerivativeAttr *, 1>>
294294
DerivativeAttrs;
295295

296296
private:

lib/Sema/TypeCheckAttr.cpp

Lines changed: 89 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3391,14 +3391,28 @@ getAutoDiffOriginalFunctionType(AnyFunctionType *derivativeFnTy) {
33913391
return originalType;
33923392
}
33933393

3394-
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3394+
/// Typechecks the given derivative attribute `attr` on decl `D`.
3395+
///
3396+
/// Effects are:
3397+
/// - Sets the original function and parameter indices on `attr`.
3398+
/// - Diagnoses errors.
3399+
/// - Stores the attribute in `ASTContext::DerivativeAttrs`.
3400+
///
3401+
/// \returns true on error, false on success.
3402+
static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
3403+
DerivativeAttr *attr) {
3404+
// Note: Implementation must be idempotent because it can get called multiple
3405+
// times for the same attribute.
3406+
3407+
auto &diags = Ctx.Diags;
3408+
33953409
// `@derivative` attribute requires experimental differentiable programming
33963410
// to be enabled.
33973411
auto &ctx = D->getASTContext();
33983412
if (!ctx.LangOpts.EnableExperimentalDifferentiableProgramming) {
3399-
diagnoseAndRemoveAttr(
3400-
attr, diag::experimental_differentiable_programming_disabled);
3401-
return;
3413+
diags.diagnose(attr->getLocation(),
3414+
diag::experimental_differentiable_programming_disabled);
3415+
return true;
34023416
}
34033417
auto *derivative = cast<FuncDecl>(D);
34043418
auto lookupConformance =
@@ -3418,26 +3432,27 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
34183432
auto derivativeResultTupleType = derivativeResultType->getAs<TupleType>();
34193433
if (!derivativeResultTupleType ||
34203434
derivativeResultTupleType->getNumElements() != 2) {
3421-
diagnoseAndRemoveAttr(attr, diag::derivative_attr_expected_result_tuple);
3422-
return;
3435+
diags.diagnose(attr->getLocation(),
3436+
diag::derivative_attr_expected_result_tuple);
3437+
return true;
34233438
}
34243439
auto valueResultElt = derivativeResultTupleType->getElement(0);
34253440
auto funcResultElt = derivativeResultTupleType->getElement(1);
34263441
// Get derivative kind and derivative function identifier.
34273442
AutoDiffDerivativeFunctionKind kind;
34283443
if (valueResultElt.getName().str() != "value") {
3429-
diagnoseAndRemoveAttr(
3430-
attr, diag::derivative_attr_invalid_result_tuple_value_label);
3431-
return;
3444+
diags.diagnose(attr->getLocation(),
3445+
diag::derivative_attr_invalid_result_tuple_value_label);
3446+
return true;
34323447
}
34333448
if (funcResultElt.getName().str() == "differential") {
34343449
kind = AutoDiffDerivativeFunctionKind::JVP;
34353450
} else if (funcResultElt.getName().str() == "pullback") {
34363451
kind = AutoDiffDerivativeFunctionKind::VJP;
34373452
} else {
3438-
diagnoseAndRemoveAttr(
3439-
attr, diag::derivative_attr_invalid_result_tuple_func_label);
3440-
return;
3453+
diags.diagnose(attr->getLocation(),
3454+
diag::derivative_attr_invalid_result_tuple_func_label);
3455+
return true;
34413456
}
34423457
attr->setDerivativeKind(kind);
34433458
// `value: R` result tuple element must conform to `Differentiable`.
@@ -3448,10 +3463,10 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
34483463
auto valueResultConf = TypeChecker::conformsToProtocol(
34493464
valueResultType, diffableProto, derivative->getDeclContext(), None);
34503465
if (!valueResultConf) {
3451-
diagnoseAndRemoveAttr(attr,
3452-
diag::derivative_attr_result_value_not_differentiable,
3453-
valueResultElt.getType());
3454-
return;
3466+
diags.diagnose(attr->getLocation(),
3467+
diag::derivative_attr_result_value_not_differentiable,
3468+
valueResultElt.getType());
3469+
return true;
34553470
}
34563471

34573472
// Compute expected original function type and look up original function.
@@ -3495,22 +3510,23 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
34953510
};
34963511

34973512
auto noneValidDiagnostic = [&]() {
3498-
diagnose(originalName.Loc,
3499-
diag::autodiff_attr_original_decl_none_valid_found,
3500-
originalName.Name, originalFnType);
3513+
diags.diagnose(originalName.Loc,
3514+
diag::autodiff_attr_original_decl_none_valid_found,
3515+
originalName.Name, originalFnType);
35013516
};
35023517
auto ambiguousDiagnostic = [&]() {
3503-
diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl,
3504-
originalName.Name, attr->getAttrName());
3518+
diags.diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl,
3519+
originalName.Name, attr->getAttrName());
35053520
};
35063521
auto notFunctionDiagnostic = [&]() {
3507-
diagnose(originalName.Loc, diag::autodiff_attr_original_decl_invalid_kind,
3508-
originalName.Name);
3522+
diags.diagnose(originalName.Loc,
3523+
diag::autodiff_attr_original_decl_invalid_kind,
3524+
originalName.Name);
35093525
};
35103526
std::function<void()> invalidTypeContextDiagnostic = [&]() {
3511-
diagnose(originalName.Loc,
3512-
diag::autodiff_attr_original_decl_not_same_type_context,
3513-
originalName.Name);
3527+
diags.diagnose(originalName.Loc,
3528+
diag::autodiff_attr_original_decl_not_same_type_context,
3529+
originalName.Name);
35143530
};
35153531

35163532
// Returns true if the derivative function and original function candidate are
@@ -3544,22 +3560,19 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
35443560
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
35453561
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
35463562
hasValidTypeContext, invalidTypeContextDiagnostic);
3547-
if (!originalAFD) {
3548-
attr->setInvalid();
3549-
return;
3550-
}
3563+
if (!originalAFD)
3564+
return true;
35513565
// Diagnose original stored properties. Stored properties cannot have custom
35523566
// registered derivatives.
35533567
if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
35543568
auto *asd = accessorDecl->getStorage();
35553569
if (asd->hasStorage()) {
3556-
diagnose(originalName.Loc,
3557-
diag::derivative_attr_original_stored_property_unsupported,
3558-
originalName.Name);
3559-
diagnose(originalAFD->getLoc(), diag::decl_declared_here,
3560-
asd->getFullName());
3561-
attr->setInvalid();
3562-
return;
3570+
diags.diagnose(originalName.Loc,
3571+
diag::derivative_attr_original_stored_property_unsupported,
3572+
originalName.Name);
3573+
diags.diagnose(originalAFD->getLoc(), diag::decl_declared_here,
3574+
asd->getFullName());
3575+
return true;
35633576
}
35643577
}
35653578
attr->setOriginalFunction(originalAFD);
@@ -3576,19 +3589,15 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
35763589
resolvedWrtParamIndices = computeDifferentiationParameters(
35773590
parsedWrtParams, derivative, derivative->getGenericEnvironment(),
35783591
attr->getAttrName(), attr->getLocation());
3579-
if (!resolvedWrtParamIndices) {
3580-
attr->setInvalid();
3581-
return;
3582-
}
3592+
if (!resolvedWrtParamIndices)
3593+
return true;
35833594

35843595
// Check if the `wrt:` parameter indices are valid.
35853596
if (checkDifferentiationParameters(
35863597
originalAFD, resolvedWrtParamIndices, originalFnType,
35873598
derivative->getGenericEnvironment(), derivative->getModuleContext(),
3588-
parsedWrtParams, attr->getLocation())) {
3589-
attr->setInvalid();
3590-
return;
3591-
}
3599+
parsedWrtParams, attr->getLocation()))
3600+
return true;
35923601

35933602
// Set the resolved `wrt:` parameter indices in the attribute.
35943603
attr->setParameterIndices(resolvedWrtParamIndices);
@@ -3647,42 +3656,56 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36473656
// Check if differential/pullback type matches expected type.
36483657
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
36493658
// Emit differential/pullback type mismatch error on attribute.
3650-
diagnoseAndRemoveAttr(attr, diag::derivative_attr_result_func_type_mismatch,
3651-
funcResultElt.getName(), originalAFD->getFullName());
3659+
diags.diagnose(attr->getLocation(),
3660+
diag::derivative_attr_result_func_type_mismatch,
3661+
funcResultElt.getName(), originalAFD->getFullName());
36523662
// Emit note with expected differential/pullback type on actual type
36533663
// location.
36543664
auto *tupleReturnTypeRepr =
36553665
cast<TupleTypeRepr>(derivative->getBodyResultTypeLoc().getTypeRepr());
36563666
auto *funcEltTypeRepr = tupleReturnTypeRepr->getElementType(1);
3657-
diagnose(funcEltTypeRepr->getStartLoc(),
3658-
diag::derivative_attr_result_func_type_mismatch_note,
3659-
funcResultElt.getName(), expectedFuncEltType)
3667+
diags
3668+
.diagnose(funcEltTypeRepr->getStartLoc(),
3669+
diag::derivative_attr_result_func_type_mismatch_note,
3670+
funcResultElt.getName(), expectedFuncEltType)
36603671
.highlight(funcEltTypeRepr->getSourceRange());
36613672
// Emit note showing original function location, if possible.
36623673
if (originalAFD->getLoc().isValid())
3663-
diagnose(originalAFD->getLoc(),
3664-
diag::derivative_attr_result_func_original_note,
3665-
originalAFD->getFullName());
3666-
return;
3674+
diags.diagnose(originalAFD->getLoc(),
3675+
diag::derivative_attr_result_func_original_note,
3676+
originalAFD->getFullName());
3677+
return true;
36673678
}
36683679

36693680
// Reject different-file derivative registration.
36703681
// TODO(TF-1021): Lift this restriction.
36713682
if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) {
3672-
diagnoseAndRemoveAttr(attr,
3673-
diag::derivative_attr_not_in_same_file_as_original);
3674-
return;
3683+
diags.diagnose(attr->getLocation(),
3684+
diag::derivative_attr_not_in_same_file_as_original);
3685+
return true;
36753686
}
36763687

36773688
// Reject duplicate `@derivative` attributes.
3678-
auto insertion = Ctx.DerivativeAttrs.try_emplace(
3679-
std::make_tuple(originalAFD, resolvedWrtParamIndices, kind), attr);
3680-
if (!insertion.second) {
3681-
diagnoseAndRemoveAttr(attr,
3682-
diag::derivative_attr_original_already_has_derivative,
3683-
originalAFD->getFullName());
3684-
diagnose(insertion.first->getSecond()->getLocation(),
3685-
diag::derivative_attr_duplicate_note);
3686-
return;
3689+
auto &derivativeAttrs = Ctx.DerivativeAttrs[std::make_tuple(
3690+
originalAFD, resolvedWrtParamIndices, kind)];
3691+
derivativeAttrs.insert(attr);
3692+
if (derivativeAttrs.size() > 1) {
3693+
diags.diagnose(attr->getLocation(),
3694+
diag::derivative_attr_original_already_has_derivative,
3695+
originalAFD->getFullName());
3696+
for (auto *duplicateAttr : derivativeAttrs) {
3697+
if (duplicateAttr == attr)
3698+
continue;
3699+
diags.diagnose(duplicateAttr->getLocation(),
3700+
diag::derivative_attr_duplicate_note);
3701+
}
3702+
return true;
36873703
}
3704+
3705+
return false;
3706+
}
3707+
3708+
void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
3709+
if (typeCheckDerivativeAttr(Ctx, D, attr))
3710+
attr->setInvalid();
36883711
}

0 commit comments

Comments
 (0)