@@ -3053,15 +3053,6 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
3053
3053
// Set the checked differentiation parameter indices in the attribute.
3054
3054
attr->setParameterIndices (checkedWrtParamIndices);
3055
3055
3056
- auto insertion =
3057
- ctx.DifferentiableAttrs .try_emplace ({D, checkedWrtParamIndices}, attr);
3058
- // `@differentiable` attributes are uniqued by their parameter indices.
3059
- // Reject duplicate attributes for the same decl and parameter indices pair.
3060
- if (!insertion.second && insertion.first ->getSecond () != attr) {
3061
- diagnoseAndRemoveAttr (attr, diag::differentiable_attr_duplicate);
3062
- return ;
3063
- }
3064
-
3065
3056
// Check that original function's result type conforms to `Differentiable`.
3066
3057
if (whereClauseGenEnv) {
3067
3058
auto originalResultInterfaceType = !originalResultTy->hasTypeParameter ()
@@ -3079,9 +3070,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
3079
3070
}
3080
3071
3081
3072
// Checks that the `candidate` function type equals the `required` function
3082
- // type, disregarding parameter labels.
3083
- //
3084
- // Precondition: `required` has no parameter labels.
3073
+ // type, disregarding parameter labels and tuple result labels.
3085
3074
std::function<bool (CanAnyFunctionType, CanType)> checkFunctionSignature;
3086
3075
checkFunctionSignature = [&](CanAnyFunctionType required,
3087
3076
CanType candidate) -> bool {
@@ -3096,21 +3085,31 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
3096
3085
required.getOptGenericSignature ())
3097
3086
return false ;
3098
3087
3099
- // Check that parameter types match (disregards labels).
3100
- if (candidateFnTy.getParams ().size () != required.getParams ().size ())
3088
+ // Check that parameter types match, disregarding labels.
3089
+ if (!std::equal (required.getParams ().begin (), required.getParams ().end (),
3090
+ candidateFnTy.getParams ().begin (),
3091
+ [](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3092
+ return x.getPlainType ()->isEqual (y.getPlainType ());
3093
+ }))
3101
3094
return false ;
3102
- for (auto paramPair : llvm::zip (candidateFnTy.getParams (),
3103
- required.getParams ()))
3104
- if (!std::get<0 >(paramPair).getPlainType ()->isEqual (
3105
- std::get<1 >(paramPair).getPlainType ()))
3106
- return false ;
3107
3095
3108
- // If required result type is non-function, check that result types match
3109
- // exactly .
3096
+ // If required result type is non-function, check that result types match.
3097
+ // If result types are tuple types, ignore labels .
3110
3098
CanAnyFunctionType requiredResultFnTy =
3111
3099
dyn_cast<AnyFunctionType>(required.getResult ());
3112
- if (!requiredResultFnTy)
3113
- return required.getResult () == candidateFnTy.getResult ();
3100
+ if (!requiredResultFnTy) {
3101
+ auto requiredResultTupleTy = required.getResult ()->getAs <TupleType>();
3102
+ auto candidateResultTupleTy =
3103
+ candidateFnTy.getResult ()->getAs <TupleType>();
3104
+ if (!requiredResultTupleTy || !candidateResultTupleTy)
3105
+ return required.getResult ()->isEqual (candidateFnTy.getResult ());
3106
+ // If result types are tuple types, check that element types match,
3107
+ // ignoring labels.
3108
+ return std::equal (requiredResultTupleTy->getElementTypes ().begin (),
3109
+ requiredResultTupleTy->getElementTypes ().end (),
3110
+ candidateResultTupleTy->getElementTypes ().begin (),
3111
+ [](Type x, Type y) { return x->isEqual (y); });
3112
+ }
3114
3113
3115
3114
// Required result type is a function. Recurse.
3116
3115
return checkFunctionSignature (requiredResultFnTy,
@@ -3168,6 +3167,15 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
3168
3167
// Memorize the vjp reference in the attribute.
3169
3168
attr->setVJPFunction (vjp);
3170
3169
}
3170
+
3171
+ auto insertion =
3172
+ ctx.DifferentiableAttrs .try_emplace ({D, checkedWrtParamIndices}, attr);
3173
+ // `@differentiable` attributes are uniqued by their parameter indices.
3174
+ // Reject duplicate attributes for the same decl and parameter indices pair.
3175
+ if (!insertion.second && insertion.first ->getSecond () != attr) {
3176
+ diagnoseAndRemoveAttr (attr, diag::differentiable_attr_duplicate);
3177
+ return ;
3178
+ }
3171
3179
}
3172
3180
3173
3181
// SWIFT_ENABLE_TENSORFLOW
@@ -3489,33 +3497,50 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
3489
3497
derivativeRequirements.push_back (req);
3490
3498
}
3491
3499
3492
- // Add the derivative to a `@differentiable` attribute on the original
3493
- // function with the same differentiation parameters. If no such
3494
- // `@differentiable` attribute exists, create one.
3500
+ // Try to find a `@differentiable` attribute on the original function with the
3501
+ // same differentiation parameters.
3495
3502
DifferentiableAttr *da = nullptr ;
3496
3503
for (auto *cda : originalFn->getAttrs ().getAttributes <DifferentiableAttr>())
3497
3504
if (checkedWrtParamIndices == cda->getParameterIndices ())
3498
3505
da = const_cast <DifferentiableAttr *>(cda);
3506
+ // If the original function does not have a `@differentiable` attribute with
3507
+ // the same differentiation parameters, create one.
3499
3508
if (!da) {
3500
3509
da = DifferentiableAttr::create (ctx, /* implicit*/ true , attr->AtLoc ,
3501
3510
attr->getRange (), checkedWrtParamIndices,
3502
- None, None, derivativeRequirements);
3511
+ /* jvp*/ None, /* vjp*/ None,
3512
+ derivativeRequirements);
3513
+ switch (kind) {
3514
+ case AutoDiffAssociatedFunctionKind::JVP:
3515
+ da->setJVPFunction (derivative);
3516
+ break ;
3517
+ case AutoDiffAssociatedFunctionKind::VJP:
3518
+ da->setVJPFunction (derivative);
3519
+ break ;
3520
+ }
3503
3521
auto insertion = ctx.DifferentiableAttrs .try_emplace (
3504
3522
{originalFn, checkedWrtParamIndices}, da);
3505
- // `@differentiable` attributes are uniqued by their parameter indices.
3506
- // Reject duplicate attributes for the same decl and parameter indices pair.
3523
+ // Valid `@differentiable` attributes are uniqued by their parameter
3524
+ // indices. Reject duplicate attributes for the same decl and parameter
3525
+ // indices pair.
3507
3526
if (!insertion.second && insertion.first ->getSecond () != da) {
3508
3527
diagnoseAndRemoveAttr (da, diag::differentiable_attr_duplicate);
3509
3528
return ;
3510
3529
}
3511
3530
originalFn->getAttrs ().add (da);
3531
+ return ;
3512
3532
}
3513
- // Check if the `@differentiable` attribute already has a registered
3514
- // derivative. If so, emit an error on the `@differentiating` attribute.
3515
- // Otherwise, register the derivative in the `@differentiable` attribute.
3533
+ // If the original function has a `@differentiable` attribute with the same
3534
+ // differentiation parameters, check if the `@differentiable` attribute
3535
+ // already has a different registered derivative. If so, emit an error on the
3536
+ // `@differentiating` attribute. Otherwise, register the derivative in the
3537
+ // `@differentiable` attribute.
3516
3538
switch (kind) {
3517
3539
case AutoDiffAssociatedFunctionKind::JVP:
3518
- if (da->getJVP () || da->getJVPFunction ()) {
3540
+ // If there's a different registered derivative, emit an error.
3541
+ if ((da->getJVP () &&
3542
+ da->getJVP ()->Name .getBaseName () != derivative->getBaseName ()) ||
3543
+ (da->getJVPFunction () && da->getJVPFunction () != derivative)) {
3519
3544
diagnoseAndRemoveAttr (
3520
3545
attr, diag::differentiating_attr_original_already_has_derivative,
3521
3546
originalFn->getFullName ());
@@ -3524,7 +3549,10 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
3524
3549
da->setJVPFunction (derivative);
3525
3550
break ;
3526
3551
case AutoDiffAssociatedFunctionKind::VJP:
3527
- if (da->getVJP () || da->getVJPFunction ()) {
3552
+ // If there's a different registered derivative, emit an error.
3553
+ if ((da->getVJP () &&
3554
+ da->getVJP ()->Name .getBaseName () != derivative->getBaseName ()) ||
3555
+ (da->getVJPFunction () && da->getVJPFunction () != derivative)) {
3528
3556
diagnoseAndRemoveAttr (
3529
3557
attr, diag::differentiating_attr_original_already_has_derivative,
3530
3558
originalFn->getFullName ());
0 commit comments