Skip to content

Commit c9bbc14

Browse files
committed
[AutoDiff] Simplify @differentiable attribute type-checking.
Unify type-checking using `AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType`. Delete `checkDifferentiabilityParameters` helper, which is subsumed. Update tests with minor diagnostic changes.
1 parent 489b6e0 commit c9bbc14

File tree

4 files changed

+76
-99
lines changed

4 files changed

+76
-99
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2978,8 +2978,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
29782978
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
29792979
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
29802980
"attribute for transpose registration instead", ())
2981-
ERROR(differentiable_attr_void_result,none,
2982-
"cannot differentiate void function %0", (DeclName))
29832981
ERROR(differentiable_attr_overload_not_found,none,
29842982
"%0 does not have expected type %1", (DeclNameRef, Type))
29852983
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
@@ -2998,9 +2996,6 @@ ERROR(differentiable_attr_invalid_access,none,
29982996
"derivative function %0 is required to either be public or "
29992997
"'@usableFromInline' because the original function %1 is public or "
30002998
"'@usableFromInline'", (DeclNameRef, DeclName))
3001-
ERROR(differentiable_attr_result_not_differentiable,none,
3002-
"can only differentiate functions with results that conform to "
3003-
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
30042999
ERROR(differentiable_attr_protocol_req_where_clause,none,
30053000
"'@differentiable' attribute on protocol requirement cannot specify "
30063001
"'where' clause", ())
@@ -3107,6 +3102,9 @@ ERROR(autodiff_attr_original_void_result,none,
31073102
ERROR(autodiff_attr_original_multiple_semantic_results,none,
31083103
"cannot differentiate functions with both an 'inout' parameter and a "
31093104
"result", ())
3105+
ERROR(autodiff_attr_result_not_differentiable,none,
3106+
"can only differentiate functions with results that conform to "
3107+
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
31103108

31113109
// differentiation `wrt` parameters clause
31123110
ERROR(diff_function_no_parameters,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 63 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3563,51 +3563,6 @@ static IndexSubset *computeDifferentiabilityParameters(
35633563
return IndexSubset::get(ctx, parameterBits);
35643564
}
35653565

3566-
// Checks if the given differentiability parameter indices are valid for the
3567-
// given original or derivative `AbstractFunctionDecl` and original function
3568-
// type in the given derivative generic environment and module context. Returns
3569-
// true on error.
3570-
//
3571-
// The parsed differentiability parameters and attribute location are used in
3572-
// diagnostics.
3573-
static bool checkDifferentiabilityParameters(
3574-
AbstractFunctionDecl *AFD, IndexSubset *diffParamIndices,
3575-
AnyFunctionType *functionType, GenericEnvironment *derivativeGenEnv,
3576-
ModuleDecl *module, ArrayRef<ParsedAutoDiffParameter> parsedDiffParams,
3577-
SourceLoc attrLoc) {
3578-
auto &ctx = AFD->getASTContext();
3579-
auto &diags = ctx.Diags;
3580-
3581-
// Diagnose empty differentiability indices. No differentiability parameters
3582-
// were resolved or inferred.
3583-
if (diffParamIndices->isEmpty()) {
3584-
diags.diagnose(attrLoc, diag::diff_params_clause_no_inferred_parameters);
3585-
return true;
3586-
}
3587-
3588-
// Check that differentiability parameters have allowed types.
3589-
SmallVector<AnyFunctionType::Param, 4> diffParams;
3590-
functionType->getSubsetParameters(diffParamIndices, diffParams);
3591-
for (unsigned i : range(diffParams.size())) {
3592-
SourceLoc loc =
3593-
parsedDiffParams.empty() ? attrLoc : parsedDiffParams[i].getLoc();
3594-
auto diffParamType = diffParams[i].getPlainType();
3595-
if (!diffParamType->hasTypeParameter())
3596-
diffParamType = diffParamType->mapTypeOutOfContext();
3597-
if (derivativeGenEnv)
3598-
diffParamType = derivativeGenEnv->mapTypeIntoContext(diffParamType);
3599-
else
3600-
diffParamType = AFD->mapTypeIntoContext(diffParamType);
3601-
// Parameter must conform to `Differentiable`.
3602-
if (!conformsToDifferentiable(diffParamType, AFD)) {
3603-
diags.diagnose(loc, diag::diff_params_clause_param_not_differentiable,
3604-
diffParamType);
3605-
return true;
3606-
}
3607-
}
3608-
return false;
3609-
}
3610-
36113566
// Returns the function declaration corresponding to the given function name and
36123567
// lookup context. If the base type of the function is specified, member lookup
36133568
// is performed. Otherwise, unqualified lookup is performed.
@@ -4103,9 +4058,11 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
41034058
/// `diffParamIndices`, and returns true.
41044059
bool resolveDifferentiableAttrDifferentiabilityParameters(
41054060
DifferentiableAttr *attr, AbstractFunctionDecl *original,
4106-
AnyFunctionType *derivativeFnTy, GenericEnvironment *derivativeGenEnv,
4061+
AnyFunctionType *originalFnRemappedTy, GenericEnvironment *derivativeGenEnv,
41074062
IndexSubset *&diffParamIndices) {
41084063
diffParamIndices = nullptr;
4064+
auto &ctx = original->getASTContext();
4065+
auto &diags = ctx.Diags;
41094066

41104067
// Get the parsed differentiability parameter indices, which have not yet been
41114068
// resolved. Parsed differentiability parameter indices are defined only for
@@ -4121,11 +4078,57 @@ bool resolveDifferentiableAttrDifferentiabilityParameters(
41214078
}
41224079

41234080
// Check if differentiability parameter indices are valid.
4124-
if (checkDifferentiabilityParameters(original, diffParamIndices,
4125-
derivativeFnTy, derivativeGenEnv,
4126-
original->getModuleContext(),
4127-
parsedDiffParams, attr->getLocation())) {
4081+
// Do this by compute the expected differential type and checking whether
4082+
// there is an error.
4083+
auto expectedLinearMapTypeOrError =
4084+
originalFnRemappedTy->getAutoDiffDerivativeFunctionLinearMapType(
4085+
diffParamIndices, AutoDiffLinearMapKind::Differential,
4086+
LookUpConformanceInModule(original->getModuleContext()),
4087+
/*makeSelfParamFirst*/ true);
4088+
4089+
// Helper for diagnosing derivative function type errors.
4090+
auto errorHandler = [&](const DerivativeFunctionTypeError &error) {
41284091
attr->setInvalid();
4092+
switch (error.kind) {
4093+
case DerivativeFunctionTypeError::Kind::NoSemanticResults:
4094+
diags
4095+
.diagnose(attr->getLocation(),
4096+
diag::autodiff_attr_original_void_result,
4097+
original->getName())
4098+
.highlight(original->getSourceRange());
4099+
return;
4100+
case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
4101+
diags
4102+
.diagnose(attr->getLocation(),
4103+
diag::autodiff_attr_original_multiple_semantic_results)
4104+
.highlight(original->getSourceRange());
4105+
return;
4106+
case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters:
4107+
diags.diagnose(attr->getLocation(),
4108+
diag::diff_params_clause_no_inferred_parameters);
4109+
return;
4110+
case DerivativeFunctionTypeError::Kind::
4111+
NonDifferentiableDifferentiabilityParameter: {
4112+
auto nonDiffParam = error.getNonDifferentiableTypeAndIndex();
4113+
SourceLoc loc = parsedDiffParams.empty()
4114+
? attr->getLocation()
4115+
: parsedDiffParams[nonDiffParam.second].getLoc();
4116+
diags.diagnose(loc, diag::diff_params_clause_param_not_differentiable,
4117+
nonDiffParam.first);
4118+
return;
4119+
}
4120+
case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4121+
auto nonDiffResult = error.getNonDifferentiableTypeAndIndex();
4122+
diags.diagnose(attr->getLocation(),
4123+
diag::autodiff_attr_result_not_differentiable,
4124+
nonDiffResult.first);
4125+
return;
4126+
}
4127+
};
4128+
// Diagnose any derivative function type errors.
4129+
if (!expectedLinearMapTypeOrError) {
4130+
auto error = expectedLinearMapTypeOrError.takeError();
4131+
handleAllErrors(std::move(error), errorHandler);
41294132
return true;
41304133
}
41314134

@@ -4222,52 +4225,19 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(
42224225
derivativeGenEnv = derivativeGenSig->getGenericEnvironment();
42234226

42244227
// Compute the derivative function type.
4225-
auto derivativeFnTy = originalFnTy;
4228+
auto originalFnRemappedTy = originalFnTy;
42264229
if (derivativeGenEnv)
4227-
derivativeFnTy = derivativeGenEnv->mapTypeIntoContext(derivativeFnTy)
4228-
->castTo<AnyFunctionType>();
4230+
originalFnRemappedTy =
4231+
derivativeGenEnv->mapTypeIntoContext(originalFnRemappedTy)
4232+
->castTo<AnyFunctionType>();
42294233

42304234
// Resolve and validate the differentiability parameters.
42314235
IndexSubset *resolvedDiffParamIndices = nullptr;
42324236
if (resolveDifferentiableAttrDifferentiabilityParameters(
4233-
attr, original, derivativeFnTy, derivativeGenEnv,
4237+
attr, original, originalFnRemappedTy, derivativeGenEnv,
42344238
resolvedDiffParamIndices))
42354239
return nullptr;
42364240

4237-
// Get the original semantic result type.
4238-
llvm::SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
4239-
autodiff::getFunctionSemanticResultTypes(originalFnTy, originalResults,
4240-
derivativeGenEnv);
4241-
// Check that original function has at least one semantic result, i.e.
4242-
// that the original semantic result type is not `Void`.
4243-
if (originalResults.empty()) {
4244-
diags
4245-
.diagnose(attr->getLocation(), diag::autodiff_attr_original_void_result,
4246-
original->getName())
4247-
.highlight(original->getSourceRange());
4248-
attr->setInvalid();
4249-
return nullptr;
4250-
}
4251-
// Check that original function does not have multiple semantic results.
4252-
if (originalResults.size() > 1) {
4253-
diags
4254-
.diagnose(attr->getLocation(),
4255-
diag::autodiff_attr_original_multiple_semantic_results)
4256-
.highlight(original->getSourceRange());
4257-
attr->setInvalid();
4258-
return nullptr;
4259-
}
4260-
auto originalResult = originalResults.front();
4261-
auto originalResultTy = originalResult.type;
4262-
// Check that the original semantic result conforms to `Differentiable`.
4263-
if (!conformsToDifferentiable(originalResultTy, original)) {
4264-
diags.diagnose(attr->getLocation(),
4265-
diag::differentiable_attr_result_not_differentiable,
4266-
originalResultTy);
4267-
attr->setInvalid();
4268-
return nullptr;
4269-
}
4270-
42714241
if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) {
42724242
// Remove `@differentiable` attribute from storage declaration to prevent
42734243
// duplicate attribute registration during SILGen.
@@ -4336,8 +4306,6 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
43364306
if (checkIfDifferentiableProgrammingEnabled(Ctx, attr, D->getDeclContext()))
43374307
return true;
43384308
auto *derivative = cast<FuncDecl>(D);
4339-
auto lookupConformance =
4340-
LookUpConformanceInModule(D->getDeclContext()->getParentModule());
43414309
auto originalName = attr->getOriginalFunctionName();
43424310

43434311
auto *derivativeInterfaceType =
@@ -4578,7 +4546,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45784546
// Compute the expected differential/pullback type.
45794547
auto expectedLinearMapTypeOrError =
45804548
originalFnType->getAutoDiffDerivativeFunctionLinearMapType(
4581-
resolvedDiffParamIndices, kind.getLinearMapKind(), lookupConformance,
4549+
resolvedDiffParamIndices, kind.getLinearMapKind(),
4550+
LookUpConformanceInModule(derivative->getModuleContext()),
45824551
/*makeSelfParamFirst*/ true);
45834552

45844553
// Helper for diagnosing derivative function type errors.
@@ -4588,7 +4557,8 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45884557
case DerivativeFunctionTypeError::Kind::NoSemanticResults:
45894558
diags
45904559
.diagnose(attr->getLocation(),
4591-
diag::autodiff_attr_original_multiple_semantic_results)
4560+
diag::autodiff_attr_original_void_result,
4561+
originalAFD->getName())
45924562
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
45934563
return;
45944564
case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
@@ -4614,7 +4584,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
46144584
case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
46154585
auto nonDiffResult = error.getNonDifferentiableTypeAndIndex();
46164586
diags.diagnose(attr->getLocation(),
4617-
diag::differentiable_attr_result_not_differentiable,
4587+
diag::autodiff_attr_result_not_differentiable,
46184588
nonDiffResult.first);
46194589
return;
46204590
}

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,14 @@ extension InoutParameters {
707707
) { fatalError() }
708708
}
709709

710+
// Test no semantic results.
711+
712+
func noSemanticResults(_ x: Float) {}
713+
714+
// expected-error @+1 {{cannot differentiate void function 'noSemanticResults'}}
715+
@derivative(of: noSemanticResults)
716+
func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}
717+
710718
// Test multiple semantic results.
711719

712720
extension InoutParameters {

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func invalidDiffWrtClass(_ x: Class) -> Class {
9292
}
9393

9494
protocol Proto {}
95-
// expected-error @+1 {{can only differentiate with respect to parameters that conform to 'Differentiable', but 'Proto' does not conform to 'Differentiable'}}
95+
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Proto' does not conform to 'Differentiable'}}
9696
@differentiable(wrt: x)
9797
func invalidDiffWrtExistential(_ x: Proto) -> Proto {
9898
return x
@@ -384,6 +384,7 @@ struct TF_521<T: FloatingPoint> {
384384
var real: T
385385
var imaginary: T
386386

387+
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'TF_521<T>' does not conform to 'Differentiable'}}
387388
@differentiable(where T: Differentiable, T == T.TangentVector)
388389
init(real: T = 0, imaginary: T = 0) {
389390
self.real = real

0 commit comments

Comments
 (0)