Skip to content

Commit 8500cf8

Browse files
authored
[AutoDiff] Revamp derivative type calculation using llvm::Expected. (#31727)
Create `DerivativeFunctionTypeError` representing all potential derivative function type calculation errors. Make `AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType` return `llvm::Expected<AnyFunctionType *>`. This is much safer and caller-friendly than performing assertions. Delete hacks in `@differentiable` and `@derivative` attribute type-checking for verifying that `Differentiable.TangentVector` type witnesses are valid: this is no longer necessary. Robust fix for TF-521: invalid `Differentiable` conformances during `@derivative` attribute type-checking. Resolves SR-12793: bad interaction between `@differentiable` and `@derivative` attribute type-checking and `Differentiable` derived conformances.
1 parent 7fbfbc5 commit 8500cf8

File tree

8 files changed

+226
-88
lines changed

8 files changed

+226
-88
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "swift/Basic/Range.h"
2828
#include "swift/Basic/SourceLoc.h"
2929
#include "llvm/ADT/StringExtras.h"
30+
#include "llvm/Support/Error.h"
3031

3132
namespace swift {
3233

@@ -388,6 +389,68 @@ class TangentSpace {
388389
NominalTypeDecl *getNominal() const;
389390
};
390391

392+
/// A derivative function type calculation error.
393+
class DerivativeFunctionTypeError
394+
: public llvm::ErrorInfo<DerivativeFunctionTypeError> {
395+
public:
396+
enum class Kind {
397+
NoSemanticResults,
398+
MultipleSemanticResults,
399+
NonDifferentiableParameters,
400+
NonDifferentiableResult
401+
};
402+
403+
static const char ID;
404+
/// The original function type.
405+
AnyFunctionType *functionType;
406+
/// The error kind.
407+
Kind kind;
408+
409+
private:
410+
union Value {
411+
IndexSubset *indices;
412+
Type type;
413+
Value(IndexSubset *indices) : indices(indices) {}
414+
Value(Type type) : type(type) {}
415+
Value() {}
416+
} value;
417+
418+
public:
419+
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind)
420+
: functionType(functionType), kind(kind), value(Value()) {
421+
assert(kind == Kind::NoSemanticResults ||
422+
kind == Kind::MultipleSemanticResults);
423+
};
424+
425+
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
426+
IndexSubset *nonDiffParameterIndices)
427+
: functionType(functionType), kind(kind), value(nonDiffParameterIndices) {
428+
assert(kind == Kind::NonDifferentiableParameters);
429+
};
430+
431+
explicit DerivativeFunctionTypeError(AnyFunctionType *functionType, Kind kind,
432+
Type nonDiffResultType)
433+
: functionType(functionType), kind(kind), value(nonDiffResultType) {
434+
assert(kind == Kind::NonDifferentiableResult);
435+
};
436+
437+
IndexSubset *getNonDifferentiableParameterIndices() const {
438+
assert(kind == Kind::NonDifferentiableParameters);
439+
return value.indices;
440+
}
441+
442+
Type getNonDifferentiableResultType() const {
443+
assert(kind == Kind::NonDifferentiableResult);
444+
return value.type;
445+
}
446+
447+
void log(raw_ostream &OS) const override;
448+
449+
std::error_code convertToErrorCode() const override {
450+
return llvm::inconvertibleErrorCode();
451+
}
452+
};
453+
391454
/// The key type used for uniquing `SILDifferentiabilityWitness` in
392455
/// `SILModule`: original function name, parameter indices, result indices, and
393456
/// derivative generic signature.

include/swift/AST/Types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "llvm/ADT/PointerEmbeddedInt.h"
4040
#include "llvm/ADT/PointerUnion.h"
4141
#include "llvm/ADT/SmallBitVector.h"
42+
#include "llvm/Support/Error.h"
4243
#include "llvm/Support/ErrorHandling.h"
4344
#include "llvm/Support/TrailingObjects.h"
4445

@@ -3350,7 +3351,7 @@ class AnyFunctionType : public TypeBase {
33503351
/// first. `makeSelfParamFirst` should be true when working with user-facing
33513352
/// derivative function types, e.g. when type-checking `@differentiable` and
33523353
/// `@derivative` attributes.
3353-
AnyFunctionType *getAutoDiffDerivativeFunctionLinearMapType(
3354+
llvm::Expected<AnyFunctionType *> getAutoDiffDerivativeFunctionLinearMapType(
33543355
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
33553356
LookupConformanceFn lookupConformance, bool makeSelfParamFirst = false);
33563357

lib/AST/AutoDiff.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,26 @@ NominalTypeDecl *TangentSpace::getNominal() const {
388388
assert(isTangentVector());
389389
return getTangentVector()->getNominalOrBoundGenericNominal();
390390
}
391+
392+
const char DerivativeFunctionTypeError::ID = '\0';
393+
394+
void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
395+
OS << "original function type '";
396+
functionType->print(OS);
397+
OS << "' ";
398+
switch (kind) {
399+
case Kind::NoSemanticResults:
400+
OS << "has no semantic results ('Void' result)";
401+
break;
402+
case Kind::MultipleSemanticResults:
403+
OS << "has multiple semantic results";
404+
break;
405+
case Kind::NonDifferentiableParameters:
406+
OS << "has non-differentiable parameters: ";
407+
value.indices->print(OS);
408+
break;
409+
case Kind::NonDifferentiableResult:
410+
OS << "has non-differentiable result: " << value.type;
411+
break;
412+
}
413+
}

lib/AST/Type.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5136,9 +5136,11 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
51365136

51375137
auto originalResult = curryLevels.back()->getResult();
51385138

5139-
Type linearMapType = getAutoDiffDerivativeFunctionLinearMapType(
5139+
auto linearMapTypeExpected = getAutoDiffDerivativeFunctionLinearMapType(
51405140
parameterIndices, kind.getLinearMapKind(), lookupConformance,
51415141
makeSelfParamFirst);
5142+
assert(linearMapTypeExpected && "Linear map type is invalid");
5143+
Type linearMapType = linearMapTypeExpected.get();
51425144

51435145
// Build the full derivative function type: `(T...) -> (R, LinearMapType)`.
51445146
SmallVector<TupleTypeElt, 2> retElts;
@@ -5165,7 +5167,8 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
51655167
return derivativeFunctionType;
51665168
}
51675169

5168-
AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
5170+
llvm::Expected<AnyFunctionType *>
5171+
AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
51695172
IndexSubset *parameterIndices, AutoDiffLinearMapKind kind,
51705173
LookupConformanceFn lookupConformance, bool makeSelfParamFirst) {
51715174
assert(!parameterIndices->isEmpty() &&
@@ -5180,15 +5183,27 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
51805183
// Get the original semantic result type.
51815184
SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
51825185
autodiff::getFunctionSemanticResultTypes(this, originalResults);
5183-
assert(originalResults.size() == 1 &&
5184-
"Only functions with one semantic result are currently supported");
5186+
// Error if no original semantic results.
5187+
if (originalResults.empty())
5188+
return llvm::make_error<DerivativeFunctionTypeError>(
5189+
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
5190+
// Error if multiple original semantic results.
5191+
// TODO(TF-1250): Support functions with multiple semantic results.
5192+
if (originalResults.size() > 1)
5193+
return llvm::make_error<DerivativeFunctionTypeError>(
5194+
this, DerivativeFunctionTypeError::Kind::MultipleSemanticResults);
51855195
auto originalResult = originalResults.front();
51865196
auto originalResultType = originalResult.type;
51875197

51885198
// Get the original semantic result type's `TangentVector` associated type.
51895199
auto resultTan =
51905200
originalResultType->getAutoDiffTangentSpace(lookupConformance);
5191-
assert(resultTan && "Original result has no tangent space?");
5201+
// Error if original semantic result has no tangent space.
5202+
if (!resultTan) {
5203+
return llvm::make_error<DerivativeFunctionTypeError>(
5204+
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5205+
originalResultType);
5206+
}
51925207
auto resultTanType = resultTan->getType();
51935208

51945209
// Compute the result linear map function type.
@@ -5213,7 +5228,13 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52135228
for (auto diffParam : diffParams) {
52145229
auto paramType = diffParam.getPlainType();
52155230
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
5216-
assert(paramTan && "Parameter has no tangent space?");
5231+
// Error if paraneter has no tangent space.
5232+
if (!paramTan) {
5233+
return llvm::make_error<DerivativeFunctionTypeError>(
5234+
this,
5235+
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
5236+
parameterIndices);
5237+
}
52175238
differentialParams.push_back(AnyFunctionType::Param(
52185239
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
52195240
if (diffParam.isInOut())
@@ -5243,7 +5264,13 @@ AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
52435264
for (auto diffParam : diffParams) {
52445265
auto paramType = diffParam.getPlainType();
52455266
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
5246-
assert(paramTan && "Parameter has no tangent space?");
5267+
// Error if paraneter has no tangent space.
5268+
if (!paramTan) {
5269+
return llvm::make_error<DerivativeFunctionTypeError>(
5270+
this,
5271+
DerivativeFunctionTypeError::Kind::NonDifferentiableParameters,
5272+
parameterIndices);
5273+
}
52475274
if (diffParam.isInOut()) {
52485275
hasInoutDiffParameter = true;
52495276
continue;

lib/Sema/TypeCheckAttr.cpp

Lines changed: 71 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3375,42 +3375,20 @@ DynamicallyReplacedDeclRequest::evaluate(Evaluator &evaluator,
33753375
return nullptr;
33763376
}
33773377

3378-
/// If the given type conforms to `Differentiable` in the given context, returns
3379-
/// the `ProtocolConformanceRef`. Otherwise, returns an invalid
3380-
/// `ProtocolConformanceRef`.
3381-
///
3382-
/// This helper verifies that the `TangentVector` type witness is valid, in case
3383-
/// the conformance has not been fully checked and the type witness cannot be
3384-
/// resolved.
3385-
static ProtocolConformanceRef getDifferentiableConformance(Type type,
3386-
DeclContext *DC) {
3387-
auto &ctx = type->getASTContext();
3388-
auto *differentiableProto =
3389-
ctx.getProtocol(KnownProtocolKind::Differentiable);
3390-
auto conf =
3391-
TypeChecker::conformsToProtocol(type, differentiableProto, DC);
3392-
if (!conf)
3393-
return ProtocolConformanceRef();
3394-
// Try to get the `TangentVector` type witness, in case the conformance has
3395-
// not been fully checked.
3396-
Type tanType = conf.getTypeWitnessByName(type, ctx.Id_TangentVector);
3397-
if (tanType.isNull() || tanType->hasError())
3398-
return ProtocolConformanceRef();
3399-
return conf;
3400-
};
3401-
34023378
/// Returns true if the given type conforms to `Differentiable` in the given
34033379
/// contxt. If `tangentVectorEqualsSelf` is true, also check whether the given
34043380
/// type satisfies `TangentVector == Self`.
34053381
static bool conformsToDifferentiable(Type type, DeclContext *DC,
34063382
bool tangentVectorEqualsSelf = false) {
3407-
auto conf = getDifferentiableConformance(type, DC);
3383+
auto &ctx = type->getASTContext();
3384+
auto *differentiableProto =
3385+
ctx.getProtocol(KnownProtocolKind::Differentiable);
3386+
auto conf = TypeChecker::conformsToProtocol(type, differentiableProto, DC);
34083387
if (conf.isInvalid())
34093388
return false;
34103389
if (!tangentVectorEqualsSelf)
34113390
return true;
3412-
auto &ctx = type->getASTContext();
3413-
Type tanType = conf.getTypeWitnessByName(type, ctx.Id_TangentVector);
3391+
auto tanType = conf.getTypeWitnessByName(type, ctx.Id_TangentVector);
34143392
return type->isEqual(tanType);
34153393
};
34163394

@@ -4602,67 +4580,81 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
46024580
// Set the resolved differentiability parameter indices in the attribute.
46034581
attr->setParameterIndices(resolvedDiffParamIndices);
46044582

4605-
// Get the original semantic result.
4606-
llvm::SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
4607-
autodiff::getFunctionSemanticResultTypes(
4608-
originalFnType, originalResults,
4609-
derivative->getGenericEnvironmentOfContext());
4610-
// Check that original function has at least one semantic result, i.e.
4611-
// that the original semantic result type is not `Void`.
4612-
if (originalResults.empty()) {
4613-
diags
4614-
.diagnose(attr->getLocation(), diag::autodiff_attr_original_void_result,
4615-
derivative->getName())
4616-
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
4617-
attr->setInvalid();
4618-
return true;
4619-
}
4620-
// Check that original function does not have multiple semantic results.
4621-
if (originalResults.size() > 1) {
4622-
diags
4623-
.diagnose(attr->getLocation(),
4624-
diag::autodiff_attr_original_multiple_semantic_results)
4625-
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
4626-
attr->setInvalid();
4627-
return true;
4628-
}
4629-
auto originalResult = originalResults.front();
4630-
auto originalResultType = originalResult.type;
4631-
// Check that the original semantic result conforms to `Differentiable`.
4632-
auto valueResultConf = getDifferentiableConformance(
4633-
originalResultType, derivative->getDeclContext());
4634-
if (!valueResultConf) {
4635-
diags.diagnose(attr->getLocation(),
4636-
diag::derivative_attr_result_value_not_differentiable,
4637-
valueResultElt.getType());
4583+
// Compute the expected differential/pullback type.
4584+
auto expectedLinearMapTypeOrError =
4585+
originalFnType->getAutoDiffDerivativeFunctionLinearMapType(
4586+
resolvedDiffParamIndices, kind.getLinearMapKind(), lookupConformance,
4587+
/*makeSelfParamFirst*/ true);
4588+
4589+
// Helper for diagnosing derivative function type errors.
4590+
auto errorHandler = [&](const DerivativeFunctionTypeError &error) {
4591+
switch (error.kind) {
4592+
case DerivativeFunctionTypeError::Kind::NoSemanticResults:
4593+
diags
4594+
.diagnose(attr->getLocation(),
4595+
diag::autodiff_attr_original_multiple_semantic_results)
4596+
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
4597+
attr->setInvalid();
4598+
return;
4599+
case DerivativeFunctionTypeError::Kind::MultipleSemanticResults:
4600+
diags
4601+
.diagnose(attr->getLocation(),
4602+
diag::autodiff_attr_original_multiple_semantic_results)
4603+
.highlight(attr->getOriginalFunctionName().Loc.getSourceRange());
4604+
attr->setInvalid();
4605+
return;
4606+
case DerivativeFunctionTypeError::Kind::NonDifferentiableParameters: {
4607+
auto *nonDiffParamIndices = error.getNonDifferentiableParameterIndices();
4608+
SmallVector<AnyFunctionType::Param, 4> diffParams;
4609+
error.functionType->getSubsetParameters(resolvedDiffParamIndices,
4610+
diffParams);
4611+
for (unsigned i : range(diffParams.size())) {
4612+
if (!nonDiffParamIndices->contains(i))
4613+
continue;
4614+
SourceLoc loc = parsedDiffParams.empty() ? attr->getLocation()
4615+
: parsedDiffParams[i].getLoc();
4616+
auto diffParamType = diffParams[i].getPlainType();
4617+
diags.diagnose(loc, diag::diff_params_clause_param_not_differentiable,
4618+
diffParamType);
4619+
}
4620+
return;
4621+
}
4622+
case DerivativeFunctionTypeError::Kind::NonDifferentiableResult:
4623+
auto originalResultType = error.getNonDifferentiableResultType();
4624+
diags.diagnose(attr->getLocation(),
4625+
diag::differentiable_attr_result_not_differentiable,
4626+
originalResultType);
4627+
attr->setInvalid();
4628+
return;
4629+
}
4630+
};
4631+
// Diagnose any derivative function type errors.
4632+
if (!expectedLinearMapTypeOrError) {
4633+
auto error = expectedLinearMapTypeOrError.takeError();
4634+
handleAllErrors(std::move(error), errorHandler);
46384635
return true;
46394636
}
4640-
4641-
// Compute the actual differential/pullback type that we use for comparison
4642-
// with the expected type. We must canonicalize the derivative interface type
4643-
// before extracting the differential/pullback type from it, so that the
4644-
// derivative interface type generic signature is available for simplifying
4645-
// types.
4637+
Type expectedLinearMapType = expectedLinearMapTypeOrError.get();
4638+
if (expectedLinearMapType->hasTypeParameter())
4639+
expectedLinearMapType =
4640+
derivative->mapTypeIntoContext(expectedLinearMapType);
4641+
if (expectedLinearMapType->hasArchetype())
4642+
expectedLinearMapType = expectedLinearMapType->mapTypeOutOfContext();
4643+
4644+
// Compute the actual differential/pullback type for comparison with the
4645+
// expected type. We must canonicalize the derivative interface type before
4646+
// extracting the differential/pullback type from it so that types are
4647+
// simplified via the canonical generic signature.
46464648
CanType canActualResultType = derivativeInterfaceType->getCanonicalType();
46474649
while (isa<AnyFunctionType>(canActualResultType)) {
46484650
canActualResultType =
46494651
cast<AnyFunctionType>(canActualResultType).getResult();
46504652
}
4651-
CanType actualFuncEltType =
4653+
CanType actualLinearMapType =
46524654
cast<TupleType>(canActualResultType).getElementType(1);
46534655

4654-
// Compute expected differential/pullback type.
4655-
Type expectedFuncEltType =
4656-
originalFnType->getAutoDiffDerivativeFunctionLinearMapType(
4657-
resolvedDiffParamIndices, kind.getLinearMapKind(), lookupConformance,
4658-
/*makeSelfParamFirst*/ true);
4659-
if (expectedFuncEltType->hasTypeParameter())
4660-
expectedFuncEltType = derivative->mapTypeIntoContext(expectedFuncEltType);
4661-
if (expectedFuncEltType->hasArchetype())
4662-
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext();
4663-
46644656
// Check if differential/pullback type matches expected type.
4665-
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
4657+
if (!actualLinearMapType->isEqual(expectedLinearMapType)) {
46664658
// Emit differential/pullback type mismatch error on attribute.
46674659
diags.diagnose(attr->getLocation(),
46684660
diag::derivative_attr_result_func_type_mismatch,
@@ -4675,7 +4667,7 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
46754667
diags
46764668
.diagnose(funcEltTypeRepr->getStartLoc(),
46774669
diag::derivative_attr_result_func_type_mismatch_note,
4678-
funcResultElt.getName(), expectedFuncEltType)
4670+
funcResultElt.getName(), expectedLinearMapType)
46794671
.highlight(funcEltTypeRepr->getSourceRange());
46804672
// Emit note showing original function location, if possible.
46814673
if (originalAFD->getLoc().isValid())

0 commit comments

Comments
 (0)