Skip to content

Commit 2e3994e

Browse files
authored
Merge pull request #31755 from dan-zheng/revamp-derivative-type-calculation
[AutoDiff] Clean up derivative type calculation.
2 parents 8aceb03 + c9bbc14 commit 2e3994e

File tree

7 files changed

+144
-162
lines changed

7 files changed

+144
-162
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,16 @@ class DerivativeFunctionTypeError
394394
: public llvm::ErrorInfo<DerivativeFunctionTypeError> {
395395
public:
396396
enum class Kind {
397+
/// Original function type has no semantic results.
397398
NoSemanticResults,
399+
/// Original function type has multiple semantic results.
400+
// TODO(TF-1250): Support function types with multiple semantic results.
398401
MultipleSemanticResults,
399-
NonDifferentiableParameters,
402+
/// Differentiability parmeter indices are empty.
403+
NoDifferentiabilityParameters,
404+
/// A differentiability parameter does not conform to `Differentiable`.
405+
NonDifferentiableDifferentiabilityParameter,
406+
/// The original result type does not conform to `Differentiable`.
400407
NonDifferentiableResult
401408
};
402409

@@ -406,42 +413,35 @@ class DerivativeFunctionTypeError
406413
/// The error kind.
407414
Kind kind;
408415

416+
/// The type and index of a differentiability parameter or result.
417+
using TypeAndIndex = std::pair<Type, unsigned>;
418+
409419
private:
410420
union Value {
411-
IndexSubset *indices;
412-
Type type;
413-
Value(IndexSubset *indices) : indices(indices) {}
414-
Value(Type type) : type(type) {}
421+
TypeAndIndex typeAndIndex;
422+
Value(TypeAndIndex typeAndIndex) : typeAndIndex(typeAndIndex) {}
415423
Value() {}
416424
} value;
417425

418426
public:
419427
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
420428
: functionType(functionType), kind(kind), value(Value()) {
421429
assert(kind == Kind::NoSemanticResults ||
422-
kind == Kind::MultipleSemanticResults);
423-
};
424-
425-
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
426-
IndexSubset *nonDiffParameterIndices)
427-
: functionType(functionType), kind(kind), value(nonDiffParameterIndices) {
428-
assert(kind == Kind::NonDifferentiableParameters);
430+
kind == Kind::MultipleSemanticResults ||
431+
kind == Kind::NoDifferentiabilityParameters);
429432
};
430433

431434
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
432-
Type nonDiffResultType)
433-
: functionType(functionType), kind(kind), value(nonDiffResultType) {
434-
assert(kind == Kind::NonDifferentiableResult);
435+
TypeAndIndex nonDiffTypeAndIndex)
436+
: functionType(functionType), kind(kind), value(nonDiffTypeAndIndex) {
437+
assert(kind == Kind::NonDifferentiableDifferentiabilityParameter ||
438+
kind == Kind::NonDifferentiableResult);
435439
};
436440

437-
IndexSubset *getNonDifferentiableParameterIndices() const {
438-
assert(kind == Kind::NonDifferentiableParameters);
439-
return value.indices;
440-
}
441-
442-
Type getNonDifferentiableResultType() const {
443-
assert(kind == Kind::NonDifferentiableResult);
444-
return value.type;
441+
TypeAndIndex getNonDifferentiableTypeAndIndex() const {
442+
assert(kind == Kind::NonDifferentiableDifferentiabilityParameter ||
443+
kind == Kind::NonDifferentiableResult);
444+
return value.typeAndIndex;
445445
}
446446

447447
void log(raw_ostream &OS) const override;

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2990,8 +2990,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
29902990
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
29912991
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
29922992
"attribute for transpose registration instead", ())
2993-
ERROR(differentiable_attr_void_result,none,
2994-
"cannot differentiate void function %0", (DeclName))
29952993
ERROR(differentiable_attr_overload_not_found,none,
29962994
"%0 does not have expected type %1", (DeclNameRef, Type))
29972995
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
@@ -3010,9 +3008,6 @@ ERROR(differentiable_attr_invalid_access,none,
30103008
"derivative function %0 is required to either be public or "
30113009
"'@usableFromInline' because the original function %1 is public or "
30123010
"'@usableFromInline'", (DeclNameRef, DeclName))
3013-
ERROR(differentiable_attr_result_not_differentiable,none,
3014-
"can only differentiate functions with results that conform to "
3015-
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
30163011
ERROR(differentiable_attr_protocol_req_where_clause,none,
30173012
"'@differentiable' attribute on protocol requirement cannot specify "
30183013
"'where' clause", ())
@@ -3119,6 +3114,9 @@ ERROR(autodiff_attr_original_void_result,none,
31193114
ERROR(autodiff_attr_original_multiple_semantic_results,none,
31203115
"cannot differentiate functions with both an 'inout' parameter and a "
31213116
"result", ())
3117+
ERROR(autodiff_attr_result_not_differentiable,none,
3118+
"can only differentiate functions with results that conform to "
3119+
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
31223120

31233121
// differentiation `wrt` parameters clause
31243122
ERROR(diff_function_no_parameters,none,

lib/AST/AutoDiff.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,12 +402,20 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
402402
case Kind::MultipleSemanticResults:
403403
OS << "has multiple semantic results";
404404
break;
405-
case Kind::NonDifferentiableParameters:
406-
OS << "has non-differentiable parameters: ";
407-
value.indices->print(OS);
405+
case Kind::NoDifferentiabilityParameters:
406+
OS << "has no differentiability parameters";
408407
break;
409-
case Kind::NonDifferentiableResult:
410-
OS << "has non-differentiable result: " << value.type;
408+
case Kind::NonDifferentiableDifferentiabilityParameter: {
409+
auto nonDiffParam = getNonDifferentiableTypeAndIndex();
410+
OS << "has non-differentiable differentiability parameter "
411+
<< nonDiffParam.second << ": " << nonDiffParam.first;
411412
break;
412413
}
414+
case Kind::NonDifferentiableResult: {
415+
auto nonDiffResult = getNonDifferentiableTypeAndIndex();
416+
OS << "has non-differentiable result " << nonDiffResult.second << ": "
417+
<< nonDiffResult.first;
418+
break;
419+
}
420+
}
413421
}

lib/AST/Type.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5172,9 +5172,11 @@ llvm::Expected<AnyFunctionType *>
51725172
AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
51735173
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
51745174
LookupConformanceFn lookupConformance, bool makeSelfParamFirst) {
5175-
assert(!parameterIndices->isEmpty() &&
5176-
"Expected at least one differentiability parameter");
51775175
auto &ctx = getASTContext();
5176+
// Error if differentiability parameter indices are empty.
5177+
if (parameterIndices->isEmpty())
5178+
return llvm::make_error<DerivativeFunctionTypeError>(
5179+
this, DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters);
51785180

51795181
// Get differentiability parameters.
51805182
SmallVector<AnyFunctionType::Param, 8> diffParams;
@@ -5203,7 +5205,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52035205
if (!resultTan) {
52045206
return llvm::make_error<DerivativeFunctionTypeError>(
52055207
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5206-
originalResultType);
5208+
std::make_pair(originalResultType, /*index*/ 0));
52075209
}
52085210
auto resultTanType = resultTan->getType();
52095211

@@ -5226,15 +5228,17 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52265228
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
52275229
SmallVector<AnyFunctionType::Param, 4> differentialParams;
52285230
bool hasInoutDiffParameter = false;
5229-
for (auto diffParam : diffParams) {
5231+
for (auto i : range(diffParams.size())) {
5232+
auto diffParam = diffParams[i];
52305233
auto paramType = diffParam.getPlainType();
52315234
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
52325235
// Error if paraneter has no tangent space.
52335236
if (!paramTan) {
52345237
return llvm::make_error<DerivativeFunctionTypeError>(
52355238
this,
5236-
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
5237-
parameterIndices);
5239+
DerivativeFunctionTypeError::Kind::
5240+
NonDifferentiableDifferentiabilityParameter,
5241+
std::make_pair(paramType, i));
52385242
}
52395243
differentialParams.push_back(AnyFunctionType::Param(
52405244
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
@@ -5262,15 +5266,17 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52625266
// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
52635267
SmallVector<TupleTypeElt, 4> pullbackResults;
52645268
bool hasInoutDiffParameter = false;
5265-
for (auto diffParam : diffParams) {
5269+
for (auto i : range(diffParams.size())) {
5270+
auto diffParam = diffParams[i];
52665271
auto paramType = diffParam.getPlainType();
52675272
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
52685273
// Error if paraneter has no tangent space.
52695274
if (!paramTan) {
52705275
return llvm::make_error<DerivativeFunctionTypeError>(
52715276
this,
5272-
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
5273-
parameterIndices);
5277+
DerivativeFunctionTypeError::Kind::
5278+
NonDifferentiableDifferentiabilityParameter,
5279+
std::make_pair(paramType, i));
52745280
}
52755281
if (diffParam.isInOut()) {
52765282
hasInoutDiffParameter = true;

0 commit comments

Comments
 (0)