Skip to content

Commit 2a63ab7

Browse files
committed
Use AttributeChecker::diagnoseAndRemoveAttr consistently.
1 parent 865c4ea commit 2a63ab7

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3371,7 +3371,6 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
33713371
derivative->getInterfaceType()->castTo<AnyFunctionType>();
33723372

33733373
// Perform preliminary `@derivative` declaration checks.
3374-
33753374
// The result type should be a two-element tuple.
33763375
// Either a value and pullback:
33773376
// (value: R, pullback: (R.TangentVector) -> (T.TangentVector...)
@@ -3381,28 +3380,25 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
33813380
auto derivativeResultTupleType = derivativeResultType->getAs<TupleType>();
33823381
if (!derivativeResultTupleType ||
33833382
derivativeResultTupleType->getNumElements() != 2) {
3384-
diagnose(attr->getLocation(), diag::derivative_attr_expected_result_tuple);
3385-
attr->setInvalid();
3383+
diagnoseAndRemoveAttr(attr, diag::derivative_attr_expected_result_tuple);
33863384
return;
33873385
}
33883386
auto valueResultElt = derivativeResultTupleType->getElement(0);
33893387
auto funcResultElt = derivativeResultTupleType->getElement(1);
33903388
// Get derivative kind and derivative function identifier.
33913389
AutoDiffDerivativeFunctionKind kind;
33923390
if (valueResultElt.getName().str() != "value") {
3393-
diagnose(attr->getLocation(),
3394-
diag::derivative_attr_invalid_result_tuple_value_label);
3395-
attr->setInvalid();
3391+
diagnoseAndRemoveAttr(
3392+
attr, diag::derivative_attr_invalid_result_tuple_value_label);
33963393
return;
33973394
}
33983395
if (funcResultElt.getName().str() == "differential") {
33993396
kind = AutoDiffDerivativeFunctionKind::JVP;
34003397
} else if (funcResultElt.getName().str() == "pullback") {
34013398
kind = AutoDiffDerivativeFunctionKind::VJP;
34023399
} else {
3403-
diagnose(attr->getLocation(),
3404-
diag::derivative_attr_invalid_result_tuple_func_label);
3405-
attr->setInvalid();
3400+
diagnoseAndRemoveAttr(
3401+
attr, diag::derivative_attr_invalid_result_tuple_func_label);
34063402
return;
34073403
}
34083404
attr->setDerivativeKind(kind);
@@ -3414,10 +3410,9 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
34143410
auto valueResultConf = TypeChecker::conformsToProtocol(
34153411
valueResultType, diffableProto, derivative->getDeclContext(), None);
34163412
if (!valueResultConf) {
3417-
diagnose(attr->getLocation(),
3418-
diag::derivative_attr_result_value_not_differentiable,
3419-
valueResultElt.getType());
3420-
attr->setInvalid();
3413+
diagnoseAndRemoveAttr(attr,
3414+
diag::derivative_attr_result_value_not_differentiable,
3415+
valueResultElt.getType());
34213416
return;
34223417
}
34233418

@@ -3602,9 +3597,8 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36023597
// Check if differential/pullback type matches expected type.
36033598
if (!funcEltType->isEqual(expectedFuncEltType)) {
36043599
// Emit differential/pullback type mismatch error on attribute.
3605-
diagnose(attr->getLocation(),
3606-
diag::derivative_attr_result_func_type_mismatch,
3607-
funcResultElt.getName(), originalAFD->getFullName());
3600+
diagnoseAndRemoveAttr(attr, diag::derivative_attr_result_func_type_mismatch,
3601+
funcResultElt.getName(), originalAFD->getFullName());
36083602
// Emit note with expected differential/pullback type on actual type
36093603
// location.
36103604
auto *tupleReturnTypeRepr =
@@ -3619,7 +3613,6 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36193613
diagnose(originalAFD->getLoc(),
36203614
diag::derivative_attr_result_func_original_note,
36213615
originalAFD->getFullName());
3622-
attr->setInvalid();
36233616
return;
36243617
}
36253618

0 commit comments

Comments
 (0)