@@ -3563,51 +3563,6 @@ static IndexSubset *computeDifferentiabilityParameters(
3563
3563
return IndexSubset::get (ctx, parameterBits);
3564
3564
}
3565
3565
3566
- // Checks if the given differentiability parameter indices are valid for the
3567
- // given original or derivative `AbstractFunctionDecl` and original function
3568
- // type in the given derivative generic environment and module context. Returns
3569
- // true on error.
3570
- //
3571
- // The parsed differentiability parameters and attribute location are used in
3572
- // diagnostics.
3573
- static bool checkDifferentiabilityParameters (
3574
- AbstractFunctionDecl *AFD, IndexSubset *diffParamIndices,
3575
- AnyFunctionType *functionType, GenericEnvironment *derivativeGenEnv,
3576
- ModuleDecl *module , ArrayRef<ParsedAutoDiffParameter> parsedDiffParams,
3577
- SourceLoc attrLoc) {
3578
- auto &ctx = AFD->getASTContext ();
3579
- auto &diags = ctx.Diags ;
3580
-
3581
- // Diagnose empty differentiability indices. No differentiability parameters
3582
- // were resolved or inferred.
3583
- if (diffParamIndices->isEmpty ()) {
3584
- diags.diagnose (attrLoc, diag::diff_params_clause_no_inferred_parameters);
3585
- return true ;
3586
- }
3587
-
3588
- // Check that differentiability parameters have allowed types.
3589
- SmallVector<AnyFunctionType::Param, 4 > diffParams;
3590
- functionType->getSubsetParameters (diffParamIndices, diffParams);
3591
- for (unsigned i : range (diffParams.size ())) {
3592
- SourceLoc loc =
3593
- parsedDiffParams.empty () ? attrLoc : parsedDiffParams[i].getLoc ();
3594
- auto diffParamType = diffParams[i].getPlainType ();
3595
- if (!diffParamType->hasTypeParameter ())
3596
- diffParamType = diffParamType->mapTypeOutOfContext ();
3597
- if (derivativeGenEnv)
3598
- diffParamType = derivativeGenEnv->mapTypeIntoContext (diffParamType);
3599
- else
3600
- diffParamType = AFD->mapTypeIntoContext (diffParamType);
3601
- // Parameter must conform to `Differentiable`.
3602
- if (!conformsToDifferentiable (diffParamType, AFD)) {
3603
- diags.diagnose (loc, diag::diff_params_clause_param_not_differentiable,
3604
- diffParamType);
3605
- return true ;
3606
- }
3607
- }
3608
- return false ;
3609
- }
3610
-
3611
3566
// Returns the function declaration corresponding to the given function name and
3612
3567
// lookup context. If the base type of the function is specified, member lookup
3613
3568
// is performed. Otherwise, unqualified lookup is performed.
@@ -4103,9 +4058,11 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
4103
4058
// / `diffParamIndices`, and returns true.
4104
4059
bool resolveDifferentiableAttrDifferentiabilityParameters (
4105
4060
DifferentiableAttr *attr, AbstractFunctionDecl *original,
4106
- AnyFunctionType *derivativeFnTy , GenericEnvironment *derivativeGenEnv,
4061
+ AnyFunctionType *originalFnRemappedTy , GenericEnvironment *derivativeGenEnv,
4107
4062
IndexSubset *&diffParamIndices) {
4108
4063
diffParamIndices = nullptr ;
4064
+ auto &ctx = original->getASTContext ();
4065
+ auto &diags = ctx.Diags ;
4109
4066
4110
4067
// Get the parsed differentiability parameter indices, which have not yet been
4111
4068
// resolved. Parsed differentiability parameter indices are defined only for
@@ -4121,11 +4078,57 @@ bool resolveDifferentiableAttrDifferentiabilityParameters(
4121
4078
}
4122
4079
4123
4080
// Check if differentiability parameter indices are valid.
4124
- if (checkDifferentiabilityParameters (original, diffParamIndices,
4125
- derivativeFnTy, derivativeGenEnv,
4126
- original->getModuleContext (),
4127
- parsedDiffParams, attr->getLocation ())) {
4081
+ // Do this by compute the expected differential type and checking whether
4082
+ // there is an error.
4083
+ auto expectedLinearMapTypeOrError =
4084
+ originalFnRemappedTy->getAutoDiffDerivativeFunctionLinearMapType (
4085
+ diffParamIndices, AutoDiffLinearMapKind::Differential,
4086
+ LookUpConformanceInModule (original->getModuleContext ()),
4087
+ /* makeSelfParamFirst*/ true );
4088
+
4089
+ // Helper for diagnosing derivative function type errors.
4090
+ auto errorHandler = [&](const DerivativeFunctionTypeError &error) {
4128
4091
attr->setInvalid ();
4092
+ switch (error.kind ) {
4093
+ case DerivativeFunctionTypeError::Kind::NoSemanticResults:
4094
+ diags
4095
+ .diagnose (attr->getLocation (),
4096
+ diag::autodiff_attr_original_void_result,
4097
+ original->getName ())
4098
+ .highlight (original->getSourceRange ());
4099
+ return ;
4100
+ case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
4101
+ diags
4102
+ .diagnose (attr->getLocation (),
4103
+ diag::autodiff_attr_original_multiple_semantic_results)
4104
+ .highlight (original->getSourceRange ());
4105
+ return ;
4106
+ case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters:
4107
+ diags.diagnose (attr->getLocation (),
4108
+ diag::diff_params_clause_no_inferred_parameters);
4109
+ return ;
4110
+ case DerivativeFunctionTypeError::Kind::
4111
+ NonDifferentiableDifferentiabilityParameter: {
4112
+ auto nonDiffParam = error.getNonDifferentiableTypeAndIndex ();
4113
+ SourceLoc loc = parsedDiffParams.empty ()
4114
+ ? attr->getLocation ()
4115
+ : parsedDiffParams[nonDiffParam.second ].getLoc ();
4116
+ diags.diagnose (loc, diag::diff_params_clause_param_not_differentiable,
4117
+ nonDiffParam.first );
4118
+ return ;
4119
+ }
4120
+ case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4121
+ auto nonDiffResult = error.getNonDifferentiableTypeAndIndex ();
4122
+ diags.diagnose (attr->getLocation (),
4123
+ diag::autodiff_attr_result_not_differentiable,
4124
+ nonDiffResult.first );
4125
+ return ;
4126
+ }
4127
+ };
4128
+ // Diagnose any derivative function type errors.
4129
+ if (!expectedLinearMapTypeOrError) {
4130
+ auto error = expectedLinearMapTypeOrError.takeError ();
4131
+ handleAllErrors (std::move (error), errorHandler);
4129
4132
return true ;
4130
4133
}
4131
4134
@@ -4222,52 +4225,19 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
4222
4225
derivativeGenEnv = derivativeGenSig->getGenericEnvironment ();
4223
4226
4224
4227
// Compute the derivative function type.
4225
- auto derivativeFnTy = originalFnTy;
4228
+ auto originalFnRemappedTy = originalFnTy;
4226
4229
if (derivativeGenEnv)
4227
- derivativeFnTy = derivativeGenEnv->mapTypeIntoContext (derivativeFnTy)
4228
- ->castTo <AnyFunctionType>();
4230
+ originalFnRemappedTy =
4231
+ derivativeGenEnv->mapTypeIntoContext (originalFnRemappedTy)
4232
+ ->castTo <AnyFunctionType>();
4229
4233
4230
4234
// Resolve and validate the differentiability parameters.
4231
4235
IndexSubset *resolvedDiffParamIndices = nullptr ;
4232
4236
if (resolveDifferentiableAttrDifferentiabilityParameters (
4233
- attr, original, derivativeFnTy , derivativeGenEnv,
4237
+ attr, original, originalFnRemappedTy , derivativeGenEnv,
4234
4238
resolvedDiffParamIndices))
4235
4239
return nullptr ;
4236
4240
4237
- // Get the original semantic result type.
4238
- llvm::SmallVector<AutoDiffSemanticFunctionResultType, 1 > originalResults;
4239
- autodiff::getFunctionSemanticResultTypes (originalFnTy, originalResults,
4240
- derivativeGenEnv);
4241
- // Check that original function has at least one semantic result, i.e.
4242
- // that the original semantic result type is not `Void`.
4243
- if (originalResults.empty ()) {
4244
- diags
4245
- .diagnose (attr->getLocation (), diag::autodiff_attr_original_void_result,
4246
- original->getName ())
4247
- .highlight (original->getSourceRange ());
4248
- attr->setInvalid ();
4249
- return nullptr ;
4250
- }
4251
- // Check that original function does not have multiple semantic results.
4252
- if (originalResults.size () > 1 ) {
4253
- diags
4254
- .diagnose (attr->getLocation (),
4255
- diag::autodiff_attr_original_multiple_semantic_results)
4256
- .highlight (original->getSourceRange ());
4257
- attr->setInvalid ();
4258
- return nullptr ;
4259
- }
4260
- auto originalResult = originalResults.front ();
4261
- auto originalResultTy = originalResult.type ;
4262
- // Check that the original semantic result conforms to `Differentiable`.
4263
- if (!conformsToDifferentiable (originalResultTy, original)) {
4264
- diags.diagnose (attr->getLocation (),
4265
- diag::differentiable_attr_result_not_differentiable,
4266
- originalResultTy);
4267
- attr->setInvalid ();
4268
- return nullptr ;
4269
- }
4270
-
4271
4241
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
4272
4242
// Remove `@differentiable` attribute from storage declaration to prevent
4273
4243
// duplicate attribute registration during SILGen.
@@ -4336,8 +4306,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4336
4306
if (checkIfDifferentiableProgrammingEnabled (Ctx, attr, D->getDeclContext ()))
4337
4307
return true ;
4338
4308
auto *derivative = cast<FuncDecl>(D);
4339
- auto lookupConformance =
4340
- LookUpConformanceInModule (D->getDeclContext ()->getParentModule ());
4341
4309
auto originalName = attr->getOriginalFunctionName ();
4342
4310
4343
4311
auto *derivativeInterfaceType =
@@ -4578,7 +4546,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4578
4546
// Compute the expected differential/pullback type.
4579
4547
auto expectedLinearMapTypeOrError =
4580
4548
originalFnType->getAutoDiffDerivativeFunctionLinearMapType (
4581
- resolvedDiffParamIndices, kind.getLinearMapKind (), lookupConformance,
4549
+ resolvedDiffParamIndices, kind.getLinearMapKind (),
4550
+ LookUpConformanceInModule (derivative->getModuleContext ()),
4582
4551
/* makeSelfParamFirst*/ true );
4583
4552
4584
4553
// Helper for diagnosing derivative function type errors.
@@ -4588,7 +4557,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4588
4557
case DerivativeFunctionTypeError::Kind::NoSemanticResults:
4589
4558
diags
4590
4559
.diagnose (attr->getLocation (),
4591
- diag::autodiff_attr_original_multiple_semantic_results)
4560
+ diag::autodiff_attr_original_void_result,
4561
+ originalAFD->getName ())
4592
4562
.highlight (attr->getOriginalFunctionName ().Loc .getSourceRange ());
4593
4563
return ;
4594
4564
case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
@@ -4614,7 +4584,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
4614
4584
case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4615
4585
auto nonDiffResult = error.getNonDifferentiableTypeAndIndex ();
4616
4586
diags.diagnose (attr->getLocation (),
4617
- diag::differentiable_attr_result_not_differentiable ,
4587
+ diag::autodiff_attr_result_not_differentiable ,
4618
4588
nonDiffResult.first );
4619
4589
return ;
4620
4590
}
0 commit comments