Skip to content

Commit 489b6e0

Browse files
committed
[AutoDiff] Clean up derivative type calculation.
Remove all assertions from `AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType`. All error cases are represented by `DerivativeFunctionTypeError` now. Fix `DerivativeFunctionTypeError` error payloads and improve error case naming.
1 parent 036d59c commit 489b6e0

File tree

4 files changed

+68
-63
lines changed

4 files changed

+68
-63
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,16 @@ class DerivativeFunctionTypeError
394394
: public llvm::ErrorInfo<DerivativeFunctionTypeError> {
395395
public:
396396
enum class Kind {
397+
/// Original function type has no semantic results.
397398
NoSemanticResults,
399+
/// Original function type has multiple semantic results.
400+
// TODO(TF-1250): Support function types with multiple semantic results.
398401
MultipleSemanticResults,
399-
NonDifferentiableParameters,
402+
/// Differentiability parmeter indices are empty.
403+
NoDifferentiabilityParameters,
404+
/// A differentiability parameter does not conform to `Differentiable`.
405+
NonDifferentiableDifferentiabilityParameter,
406+
/// The original result type does not conform to `Differentiable`.
400407
NonDifferentiableResult
401408
};
402409

@@ -406,42 +413,35 @@ class DerivativeFunctionTypeError
406413
/// The error kind.
407414
Kind kind;
408415

416+
/// The type and index of a differentiability parameter or result.
417+
using TypeAndIndex = std::pair<Type, unsigned>;
418+
409419
private:
410420
union Value {
411-
IndexSubset *indices;
412-
Type type;
413-
Value(IndexSubset *indices) : indices(indices) {}
414-
Value(Type type) : type(type) {}
421+
TypeAndIndex typeAndIndex;
422+
Value(TypeAndIndex typeAndIndex) : typeAndIndex(typeAndIndex) {}
415423
Value() {}
416424
} value;
417425

418426
public:
419427
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
420428
: functionType(functionType), kind(kind), value(Value()) {
421429
assert(kind == Kind::NoSemanticResults ||
422-
kind == Kind::MultipleSemanticResults);
423-
};
424-
425-
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
426-
IndexSubset *nonDiffParameterIndices)
427-
: functionType(functionType), kind(kind), value(nonDiffParameterIndices) {
428-
assert(kind == Kind::NonDifferentiableParameters);
430+
kind == Kind::MultipleSemanticResults ||
431+
kind == Kind::NoDifferentiabilityParameters);
429432
};
430433

431434
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
432-
Type nonDiffResultType)
433-
: functionType(functionType), kind(kind), value(nonDiffResultType) {
434-
assert(kind == Kind::NonDifferentiableResult);
435+
TypeAndIndex nonDiffTypeAndIndex)
436+
: functionType(functionType), kind(kind), value(nonDiffTypeAndIndex) {
437+
assert(kind == Kind::NonDifferentiableDifferentiabilityParameter ||
438+
kind == Kind::NonDifferentiableResult);
435439
};
436440

437-
IndexSubset *getNonDifferentiableParameterIndices() const {
438-
assert(kind == Kind::NonDifferentiableParameters);
439-
return value.indices;
440-
}
441-
442-
Type getNonDifferentiableResultType() const {
443-
assert(kind == Kind::NonDifferentiableResult);
444-
return value.type;
441+
TypeAndIndex getNonDifferentiableTypeAndIndex() const {
442+
assert(kind == Kind::NonDifferentiableDifferentiabilityParameter ||
443+
kind == Kind::NonDifferentiableResult);
444+
return value.typeAndIndex;
445445
}
446446

447447
void log(raw_ostream &OS) const override;

lib/AST/AutoDiff.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,12 +402,20 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
402402
case Kind::MultipleSemanticResults:
403403
OS << "has multiple semantic results";
404404
break;
405-
case Kind::NonDifferentiableParameters:
406-
OS << "has non-differentiable parameters: ";
407-
value.indices->print(OS);
405+
case Kind::NoDifferentiabilityParameters:
406+
OS << "has no differentiability parameters";
408407
break;
409-
case Kind::NonDifferentiableResult:
410-
OS << "has non-differentiable result: " << value.type;
408+
case Kind::NonDifferentiableDifferentiabilityParameter: {
409+
auto nonDiffParam = getNonDifferentiableTypeAndIndex();
410+
OS << "has non-differentiable differentiability parameter "
411+
<< nonDiffParam.second << ": " << nonDiffParam.first;
411412
break;
412413
}
414+
case Kind::NonDifferentiableResult: {
415+
auto nonDiffResult = getNonDifferentiableTypeAndIndex();
416+
OS << "has non-differentiable result " << nonDiffResult.second << ": "
417+
<< nonDiffResult.first;
418+
break;
419+
}
420+
}
413421
}

lib/AST/Type.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5171,9 +5171,11 @@ llvm::Expected<AnyFunctionType *>
51715171
AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
51725172
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
51735173
LookupConformanceFn lookupConformance, bool makeSelfParamFirst) {
5174-
assert(!parameterIndices->isEmpty() &&
5175-
"Expected at least one differentiability parameter");
51765174
auto &ctx = getASTContext();
5175+
// Error if differentiability parameter indices are empty.
5176+
if (parameterIndices->isEmpty())
5177+
return llvm::make_error<DerivativeFunctionTypeError>(
5178+
this, DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters);
51775179

51785180
// Get differentiability parameters.
51795181
SmallVector<AnyFunctionType::Param, 8> diffParams;
@@ -5202,7 +5204,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52025204
if (!resultTan) {
52035205
return llvm::make_error<DerivativeFunctionTypeError>(
52045206
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5205-
originalResultType);
5207+
std::make_pair(originalResultType, /*index*/ 0));
52065208
}
52075209
auto resultTanType = resultTan->getType();
52085210

@@ -5225,15 +5227,17 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52255227
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
52265228
SmallVector<AnyFunctionType::Param, 4> differentialParams;
52275229
bool hasInoutDiffParameter = false;
5228-
for (auto diffParam : diffParams) {
5230+
for (auto i : range(diffParams.size())) {
5231+
auto diffParam = diffParams[i];
52295232
auto paramType = diffParam.getPlainType();
52305233
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
52315234
// Error if paraneter has no tangent space.
52325235
if (!paramTan) {
52335236
return llvm::make_error<DerivativeFunctionTypeError>(
52345237
this,
5235-
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
5236-
parameterIndices);
5238+
DerivativeFunctionTypeError::Kind::
5239+
NonDifferentiableDifferentiabilityParameter,
5240+
std::make_pair(paramType, i));
52375241
}
52385242
differentialParams.push_back(AnyFunctionType::Param(
52395243
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
@@ -5261,15 +5265,17 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52615265
// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
52625266
SmallVector<TupleTypeElt, 4> pullbackResults;
52635267
bool hasInoutDiffParameter = false;
5264-
for (auto diffParam : diffParams) {
5268+
for (auto i : range(diffParams.size())) {
5269+
auto diffParam = diffParams[i];
52655270
auto paramType = diffParam.getPlainType();
52665271
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
52675272
// Error if paraneter has no tangent space.
52685273
if (!paramTan) {
52695274
return llvm::make_error<DerivativeFunctionTypeError>(
52705275
this,
5271-
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
5272-
parameterIndices);
5276+
DerivativeFunctionTypeError::Kind::
5277+
NonDifferentiableDifferentiabilityParameter,
5278+
std::make_pair(paramType, i));
52735279
}
52745280
if (diffParam.isInOut()) {
52755281
hasInoutDiffParameter = true;

lib/Sema/TypeCheckAttr.cpp

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4570,14 +4570,9 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45704570
if (!resolvedDiffParamIndices)
45714571
return true;
45724572

4573-
// Check if the differentiability parameter indices are valid.
4574-
if (checkDifferentiabilityParameters(
4575-
originalAFD, resolvedDiffParamIndices, originalFnType,
4576-
derivative->getGenericEnvironment(), derivative->getModuleContext(),
4577-
parsedDiffParams, attr->getLocation()))
4578-
return true;
4579-
45804573
// Set the resolved differentiability parameter indices in the attribute.
4574+
// Differentiability parameter indices verification is done by
4575+
// `AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType` below.
45814576
attr->setParameterIndices(resolvedDiffParamIndices);
45824577

45834578
// Compute the expected differential/pullback type.
@@ -4588,43 +4583,39 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45884583

45894584
// Helper for diagnosing derivative function type errors.
45904585
auto errorHandler = [&](const DerivativeFunctionTypeError &error) {
4586+
attr->setInvalid();
45914587
switch (error.kind) {
45924588
case DerivativeFunctionTypeError::Kind::NoSemanticResults:
45934589
diags
45944590
.diagnose(attr->getLocation(),
45954591
diag::autodiff_attr_original_multiple_semantic_results)
45964592
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
4597-
attr->setInvalid();
45984593
return;
45994594
case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
46004595
diags
46014596
.diagnose(attr->getLocation(),
46024597
diag::autodiff_attr_original_multiple_semantic_results)
46034598
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
4604-
attr->setInvalid();
46054599
return;
4606-
case DerivativeFunctionTypeError::Kind::NonDifferentiableParameters: {
4607-
auto *nonDiffParamIndices = error.getNonDifferentiableParameterIndices();
4608-
SmallVector<AnyFunctionType::Param, 4> diffParams;
4609-
error.functionType->getSubsetParameters(resolvedDiffParamIndices,
4610-
diffParams);
4611-
for (unsigned i : range(diffParams.size())) {
4612-
if (!nonDiffParamIndices->contains(i))
4613-
continue;
4614-
SourceLoc loc = parsedDiffParams.empty() ? attr->getLocation()
4615-
: parsedDiffParams[i].getLoc();
4616-
auto diffParamType = diffParams[i].getPlainType();
4617-
diags.diagnose(loc, diag::diff_params_clause_param_not_differentiable,
4618-
diffParamType);
4619-
}
4600+
case DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters:
4601+
diags.diagnose(attr->getLocation(),
4602+
diag::diff_params_clause_no_inferred_parameters);
4603+
return;
4604+
case DerivativeFunctionTypeError::Kind::
4605+
NonDifferentiableDifferentiabilityParameter: {
4606+
auto nonDiffParam = error.getNonDifferentiableTypeAndIndex();
4607+
SourceLoc loc = parsedDiffParams.empty()
4608+
? attr->getLocation()
4609+
: parsedDiffParams[nonDiffParam.second].getLoc();
4610+
diags.diagnose(loc, diag::diff_params_clause_param_not_differentiable,
4611+
nonDiffParam.first);
46204612
return;
46214613
}
46224614
case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4623-
auto originalResultType = error.getNonDifferentiableResultType();
4615+
auto nonDiffResult = error.getNonDifferentiableTypeAndIndex();
46244616
diags.diagnose(attr->getLocation(),
46254617
diag::differentiable_attr_result_not_differentiable,
4626-
originalResultType);
4627-
attr->setInvalid();
4618+
nonDiffResult.first);
46284619
return;
46294620
}
46304621
};

0 commit comments

Comments
 (0)