Skip to content

[AutoDiff] Clean up derivative type calculation. #31755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,16 @@ class DerivativeFunctionTypeError
: public llvm::ErrorInfo<DerivativeFunctionTypeError> {
public:
enum class Kind {
/// Original function type has no semantic results.
NoSemanticResults,
/// Original function type has multiple semantic results.
// TODO(TF-1250): Support function types with multiple semantic results.
MultipleSemanticResults,
NonDifferentiableParameters,
/// Differentiability parmeter indices are empty.
NoDifferentiabilityParameters,
/// A differentiability parameter does not conform to `Differentiable`.
NonDifferentiableDifferentiabilityParameter,
/// The original result type does not conform to `Differentiable`.
NonDifferentiableResult
};

Expand All @@ -406,42 +413,35 @@ class DerivativeFunctionTypeError
/// The error kind.
Kind kind;

/// The type and index of a differentiability parameter or result.
using TypeAndIndex = std::pair<Type, unsigned>;

private:
union Value {
IndexSubset *indices;
Type type;
Value(IndexSubset *indices) : indices(indices) {}
Value(Type type) : type(type) {}
TypeAndIndex typeAndIndex;
Value(TypeAndIndex typeAndIndex) : typeAndIndex(typeAndIndex) {}
Value() {}
} value;

public:
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
: functionType(functionType), kind(kind), value(Value()) {
assert(kind == Kind::NoSemanticResults ||
kind == Kind::MultipleSemanticResults);
};

explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
IndexSubset *nonDiffParameterIndices)
: functionType(functionType), kind(kind), value(nonDiffParameterIndices) {
assert(kind == Kind::NonDifferentiableParameters);
kind == Kind::MultipleSemanticResults ||
kind == Kind::NoDifferentiabilityParameters);
};

explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
Type nonDiffResultType)
: functionType(functionType), kind(kind), value(nonDiffResultType) {
assert(kind == Kind::NonDifferentiableResult);
TypeAndIndex nonDiffTypeAndIndex)
: functionType(functionType), kind(kind), value(nonDiffTypeAndIndex) {
assert(kind == Kind::NonDifferentiableDifferentiabilityParameter ||
kind == Kind::NonDifferentiableResult);
};

IndexSubset *getNonDifferentiableParameterIndices() const {
assert(kind == Kind::NonDifferentiableParameters);
return value.indices;
}

Type getNonDifferentiableResultType() const {
assert(kind == Kind::NonDifferentiableResult);
return value.type;
TypeAndIndex getNonDifferentiableTypeAndIndex() const {
assert(kind == Kind::NonDifferentiableDifferentiabilityParameter ||
kind == Kind::NonDifferentiableResult);
return value.typeAndIndex;
}

void log(raw_ostream &OS) const override;
Expand Down
8 changes: 3 additions & 5 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2978,8 +2978,6 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
ERROR(differentiable_attr_no_vjp_or_jvp_when_linear,none,
"cannot specify 'vjp:' or 'jvp:' for linear functions; use '@transpose' "
"attribute for transpose registration instead", ())
ERROR(differentiable_attr_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(differentiable_attr_overload_not_found,none,
"%0 does not have expected type %1", (DeclNameRef, Type))
// TODO(TF-482): Change duplicate `@differentiable` attribute diagnostic to also
Expand All @@ -2998,9 +2996,6 @@ ERROR(differentiable_attr_invalid_access,none,
"derivative function %0 is required to either be public or "
"'@usableFromInline' because the original function %1 is public or "
"'@usableFromInline'", (DeclNameRef, DeclName))
ERROR(differentiable_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
ERROR(differentiable_attr_protocol_req_where_clause,none,
"'@differentiable' attribute on protocol requirement cannot specify "
"'where' clause", ())
Expand Down Expand Up @@ -3107,6 +3102,9 @@ ERROR(autodiff_attr_original_void_result,none,
ERROR(autodiff_attr_original_multiple_semantic_results,none,
"cannot differentiate functions with both an 'inout' parameter and a "
"result", ())
ERROR(autodiff_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))

// differentiation `wrt` parameters clause
ERROR(diff_function_no_parameters,none,
Expand Down
18 changes: 13 additions & 5 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,20 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
case Kind::MultipleSemanticResults:
OS << "has multiple semantic results";
break;
case Kind::NonDifferentiableParameters:
OS << "has non-differentiable parameters: ";
value.indices->print(OS);
case Kind::NoDifferentiabilityParameters:
OS << "has no differentiability parameters";
break;
case Kind::NonDifferentiableResult:
OS << "has non-differentiable result: " << value.type;
case Kind::NonDifferentiableDifferentiabilityParameter: {
auto nonDiffParam = getNonDifferentiableTypeAndIndex();
OS << "has non-differentiable differentiability parameter "
<< nonDiffParam.second << ": " << nonDiffParam.first;
break;
}
case Kind::NonDifferentiableResult: {
auto nonDiffResult = getNonDifferentiableTypeAndIndex();
OS << "has non-differentiable result " << nonDiffResult.second << ": "
<< nonDiffResult.first;
break;
}
}
}
24 changes: 15 additions & 9 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5171,9 +5171,11 @@ llvm::Expected<AnyFunctionType *>
AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
LookupConformanceFn lookupConformance, bool makeSelfParamFirst) {
assert(!parameterIndices->isEmpty() &&
"Expected at least one differentiability parameter");
auto &ctx = getASTContext();
// Error if differentiability parameter indices are empty.
if (parameterIndices->isEmpty())
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::NoDifferentiabilityParameters);

// Get differentiability parameters.
SmallVector<AnyFunctionType::Param, 8> diffParams;
Expand Down Expand Up @@ -5202,7 +5204,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
if (!resultTan) {
return llvm::make_error<DerivativeFunctionTypeError>(
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
originalResultType);
std::make_pair(originalResultType, /*index*/ 0));
}
auto resultTanType = resultTan->getType();

Expand All @@ -5225,15 +5227,17 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
SmallVector<AnyFunctionType::Param, 4> differentialParams;
bool hasInoutDiffParameter = false;
for (auto diffParam : diffParams) {
for (auto i : range(diffParams.size())) {
auto diffParam = diffParams[i];
auto paramType = diffParam.getPlainType();
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
// Error if paraneter has no tangent space.
if (!paramTan) {
return llvm::make_error<DerivativeFunctionTypeError>(
this,
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
parameterIndices);
DerivativeFunctionTypeError::Kind::
NonDifferentiableDifferentiabilityParameter,
std::make_pair(paramType, i));
}
differentialParams.push_back(AnyFunctionType::Param(
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
Expand Down Expand Up @@ -5261,15 +5265,17 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
SmallVector<TupleTypeElt, 4> pullbackResults;
bool hasInoutDiffParameter = false;
for (auto diffParam : diffParams) {
for (auto i : range(diffParams.size())) {
auto diffParam = diffParams[i];
auto paramType = diffParam.getPlainType();
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
// Error if paraneter has no tangent space.
if (!paramTan) {
return llvm::make_error<DerivativeFunctionTypeError>(
this,
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
parameterIndices);
DerivativeFunctionTypeError::Kind::
NonDifferentiableDifferentiabilityParameter,
std::make_pair(paramType, i));
}
if (diffParam.isInOut()) {
hasInoutDiffParameter = true;
Expand Down
Loading