@@ -3375,42 +3375,20 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
3375
3375
return nullptr ;
3376
3376
}
3377
3377
3378
- // / If the given type conforms to `Differentiable` in the given context, returns
3379
- // / the `ProtocolConformanceRef`. Otherwise, returns an invalid
3380
- // / `ProtocolConformanceRef`.
3381
- // /
3382
- // / This helper verifies that the `TangentVector` type witness is valid, in case
3383
- // / the conformance has not been fully checked and the type witness cannot be
3384
- // / resolved.
3385
- static ProtocolConformanceRef getDifferentiableConformance (Type type,
3386
- DeclContext *DC) {
3387
- auto &ctx = type->getASTContext ();
3388
- auto *differentiableProto =
3389
- ctx.getProtocol (KnownProtocolKind::Differentiable);
3390
- auto conf =
3391
- TypeChecker::conformsToProtocol (type, differentiableProto, DC);
3392
- if (!conf)
3393
- return ProtocolConformanceRef ();
3394
- // Try to get the `TangentVector` type witness, in case the conformance has
3395
- // not been fully checked.
3396
- Type tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3397
- if (tanType.isNull () || tanType->hasError ())
3398
- return ProtocolConformanceRef ();
3399
- return conf;
3400
- };
3401
-
3402
3378
// / Returns true if the given type conforms to `Differentiable` in the given
3403
3379
// / contxt. If `tangentVectorEqualsSelf` is true, also check whether the given
3404
3380
// / type satisfies `TangentVector == Self`.
3405
3381
static bool conformsToDifferentiable (Type type, DeclContext *DC,
3406
3382
bool tangentVectorEqualsSelf = false ) {
3407
- auto conf = getDifferentiableConformance (type, DC);
3383
+ auto &ctx = type->getASTContext ();
3384
+ auto *differentiableProto =
3385
+ ctx.getProtocol (KnownProtocolKind::Differentiable);
3386
+ auto conf = TypeChecker::conformsToProtocol (type, differentiableProto, DC);
3408
3387
if (conf.isInvalid ())
3409
3388
return false ;
3410
3389
if (!tangentVectorEqualsSelf)
3411
3390
return true ;
3412
- auto &ctx = type->getASTContext ();
3413
- Type tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3391
+ auto tanType = conf.getTypeWitnessByName (type, ctx.Id_TangentVector );
3414
3392
return type->isEqual (tanType);
3415
3393
};
3416
3394
@@ -4602,67 +4580,81 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4602
4580
// Set the resolved differentiability parameter indices in the attribute.
4603
4581
attr->setParameterIndices (resolvedDiffParamIndices);
4604
4582
4605
- // Get the original semantic result.
4606
- llvm::SmallVector<AutoDiffSemanticFunctionResultType, 1 > originalResults;
4607
- autodiff::getFunctionSemanticResultTypes (
4608
- originalFnType, originalResults,
4609
- derivative->getGenericEnvironmentOfContext ());
4610
- // Check that original function has at least one semantic result, i.e.
4611
- // that the original semantic result type is not `Void`.
4612
- if (originalResults.empty ()) {
4613
- diags
4614
- .diagnose (attr->getLocation (), diag::autodiff_attr_original_void_result,
4615
- derivative->getName ())
4616
- .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4617
- attr->setInvalid ();
4618
- return true ;
4619
- }
4620
- // Check that original function does not have multiple semantic results.
4621
- if (originalResults.size () > 1 ) {
4622
- diags
4623
- .diagnose (attr->getLocation (),
4624
- diag::autodiff_attr_original_multiple_semantic_results)
4625
- .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4626
- attr->setInvalid ();
4627
- return true ;
4628
- }
4629
- auto originalResult = originalResults.front ();
4630
- auto originalResultType = originalResult.type ;
4631
- // Check that the original semantic result conforms to `Differentiable`.
4632
- auto valueResultConf = getDifferentiableConformance (
4633
- originalResultType, derivative->getDeclContext ());
4634
- if (!valueResultConf) {
4635
- diags.diagnose (attr->getLocation (),
4636
- diag::derivative_attr_result_value_not_differentiable,
4637
- valueResultElt.getType ());
4583
+ // Compute the expected differential/pullback type.
4584
+ auto expectedLinearMapTypeOrError =
4585
+ originalFnType->getAutoDiffDerivativeFunctionLinearMapType (
4586
+ resolvedDiffParamIndices, kind.getLinearMapKind (), lookupConformance,
4587
+ /* makeSelfParamFirst*/ true );
4588
+
4589
+ // Helper for diagnosing derivative function type errors.
4590
+ auto errorHandler = [&](const DerivativeFunctionTypeError &error) {
4591
+ switch (error.kind ) {
4592
+ case DerivativeFunctionTypeError::Kind::NoSemanticResults:
4593
+ diags
4594
+ .diagnose (attr->getLocation (),
4595
+ diag::autodiff_attr_original_multiple_semantic_results)
4596
+ .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4597
+ attr->setInvalid ();
4598
+ return ;
4599
+ case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
4600
+ diags
4601
+ .diagnose (attr->getLocation (),
4602
+ diag::autodiff_attr_original_multiple_semantic_results)
4603
+ .highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4604
+ attr->setInvalid ();
4605
+ return ;
4606
+ case DerivativeFunctionTypeError::Kind::NonDifferentiableParameters: {
4607
+ auto *nonDiffParamIndices = error.getNonDifferentiableParameterIndices ();
4608
+ SmallVector<AnyFunctionType::Param, 4 > diffParams;
4609
+ error.functionType ->getSubsetParameters (resolvedDiffParamIndices,
4610
+ diffParams);
4611
+ for (unsigned i : range (diffParams.size ())) {
4612
+ if (!nonDiffParamIndices->contains (i))
4613
+ continue ;
4614
+ SourceLoc loc = parsedDiffParams.empty () ? attr->getLocation ()
4615
+ : parsedDiffParams[i].getLoc ();
4616
+ auto diffParamType = diffParams[i].getPlainType ();
4617
+ diags.diagnose (loc, diag::diff_params_clause_param_not_differentiable,
4618
+ diffParamType);
4619
+ }
4620
+ return ;
4621
+ }
4622
+ case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4623
+ auto originalResultType = error.getNonDifferentiableResultType ();
4624
+ diags.diagnose (attr->getLocation (),
4625
+ diag::differentiable_attr_result_not_differentiable,
4626
+ originalResultType);
4627
+ attr->setInvalid ();
4628
+ return ;
4629
+ }
4630
+ };
4631
+ // Diagnose any derivative function type errors.
4632
+ if (!expectedLinearMapTypeOrError) {
4633
+ auto error = expectedLinearMapTypeOrError.takeError ();
4634
+ handleAllErrors (std::move (error), errorHandler);
4638
4635
return true ;
4639
4636
}
4640
-
4641
- // Compute the actual differential/pullback type that we use for comparison
4642
- // with the expected type. We must canonicalize the derivative interface type
4643
- // before extracting the differential/pullback type from it, so that the
4644
- // derivative interface type generic signature is available for simplifying
4645
- // types.
4637
+ Type expectedLinearMapType = expectedLinearMapTypeOrError.get ();
4638
+ if (expectedLinearMapType->hasTypeParameter ())
4639
+ expectedLinearMapType =
4640
+ derivative->mapTypeIntoContext (expectedLinearMapType);
4641
+ if (expectedLinearMapType->hasArchetype ())
4642
+ expectedLinearMapType = expectedLinearMapType->mapTypeOutOfContext ();
4643
+
4644
+ // Compute the actual differential/pullback type for comparison with the
4645
+ // expected type. We must canonicalize the derivative interface type before
4646
+ // extracting the differential/pullback type from it so that types are
4647
+ // simplified via the canonical generic signature.
4646
4648
CanType canActualResultType = derivativeInterfaceType->getCanonicalType ();
4647
4649
while (isa<AnyFunctionType>(canActualResultType)) {
4648
4650
canActualResultType =
4649
4651
cast<AnyFunctionType>(canActualResultType).getResult ();
4650
4652
}
4651
- CanType actualFuncEltType =
4653
+ CanType actualLinearMapType =
4652
4654
cast<TupleType>(canActualResultType).getElementType (1 );
4653
4655
4654
- // Compute expected differential/pullback type.
4655
- Type expectedFuncEltType =
4656
- originalFnType->getAutoDiffDerivativeFunctionLinearMapType (
4657
- resolvedDiffParamIndices, kind.getLinearMapKind (), lookupConformance,
4658
- /* makeSelfParamFirst*/ true );
4659
- if (expectedFuncEltType->hasTypeParameter ())
4660
- expectedFuncEltType = derivative->mapTypeIntoContext (expectedFuncEltType);
4661
- if (expectedFuncEltType->hasArchetype ())
4662
- expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext ();
4663
-
4664
4656
// Check if differential/pullback type matches expected type.
4665
- if (!actualFuncEltType ->isEqual (expectedFuncEltType )) {
4657
+ if (!actualLinearMapType ->isEqual (expectedLinearMapType )) {
4666
4658
// Emit differential/pullback type mismatch error on attribute.
4667
4659
diags.diagnose (attr->getLocation (),
4668
4660
diag::derivative_attr_result_func_type_mismatch,
@@ -4675,7 +4667,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4675
4667
diags
4676
4668
.diagnose (funcEltTypeRepr->getStartLoc (),
4677
4669
diag::derivative_attr_result_func_type_mismatch_note,
4678
- funcResultElt.getName (), expectedFuncEltType )
4670
+ funcResultElt.getName (), expectedLinearMapType )
4679
4671
.highlight (funcEltTypeRepr->getSourceRange ());
4680
4672
// Emit note showing original function location, if possible.
4681
4673
if (originalAFD->getLoc ().isValid ())
0 commit comments