Skip to content

[AutoDiff upstream] @derivative attribute type-checking fixes. #28853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3263,25 +3263,34 @@ static bool checkFunctionSignature(
return false;
}

// Map type into the required function type's generic signature, if it exists.
// This is significant when the required generic signature has same-type
// requirements while the candidate generic signature does not.
auto mapType = [&](Type type) {
if (!requiredGenSig)
return type->getCanonicalType();
return requiredGenSig->getCanonicalTypeInContext(type);
};

// Check that parameter types match, disregarding labels.
if (required->getNumParams() != candidateFnTy->getNumParams())
return false;
if (!std::equal(required->getParams().begin(), required->getParams().end(),
candidateFnTy->getParams().begin(),
[](AnyFunctionType::Param x, AnyFunctionType::Param y) {
return x.getPlainType()->isEqual(y.getPlainType());
[&](AnyFunctionType::Param x, AnyFunctionType::Param y) {
return x.getPlainType()->isEqual(mapType(y.getPlainType()));
}))
return false;

// If required result type is not a function type, check that result types
// match exactly.
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult());
auto candidateResultTy = mapType(candidateFnTy.getResult());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it's possible to remap candidateFnTy once, instead of remapping its individual components.
Filed TF-1073 to track this improvement.

if (!requiredResultFnTy) {
auto requiredResultTupleTy = dyn_cast<TupleType>(required.getResult());
auto candidateResultTupleTy =
dyn_cast<TupleType>(candidateFnTy.getResult());
auto candidateResultTupleTy = dyn_cast<TupleType>(candidateResultTy);
if (!requiredResultTupleTy || !candidateResultTupleTy)
return required.getResult()->isEqual(candidateFnTy.getResult());
return required.getResult()->isEqual(candidateResultTy);
// If result types are tuple types, check that element types match,
// ignoring labels.
if (requiredResultTupleTy->getNumElements() !=
Expand All @@ -3294,7 +3303,7 @@ static bool checkFunctionSignature(
}

// Required result type is a function. Recurse.
return checkFunctionSignature(requiredResultFnTy, candidateFnTy.getResult());
return checkFunctionSignature(requiredResultFnTy, candidateResultTy);
};

// Returns an `AnyFunctionType` with the same `ExtInfo` as `fnType`, but with
Expand Down Expand Up @@ -3578,8 +3587,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
auto resultTanType = valueResultConf.getTypeWitnessByName(
valueResultType, Ctx.Id_TangentVector);

// Compute the actual differential/pullback type that we use for comparison
// with the expected type. We must canonicalize the derivative interface type
// before extracting the differential/pullback type from it, so that the
// derivative interface type generic signature is available for simplifying
// types.
CanType canActualResultType = derivativeInterfaceType->getCanonicalType();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it's possible to canonicalize derivativeInterfaceType earlier, like:

-  auto *derivativeInterfaceType = derivative->getInterfaceType()
-      ->castTo<AnyFunctionType>();
+  auto derivativeInterfaceType = dyn_cast<AnyFunctionType>(
+      derivative->getInterfaceType()->getCanonicalType());

I tried this on tensorflow branch but it didn't work out of the box:

/Users/danielzheng/swift-tf/swift/test/AutoDiff/derivative_attr_type_checking.swift:58:25: error: incorrect message found
Assertion failed: ((size_t)sys::locale::columnWidth(I->getText()) == I->getText().size()), function buildFixItLine, file /Users/danielzheng/swift-tf/llvm-project/llvm/lib/Support/SourceMgr.cpp, line 316.

I believe this error occurs when diagnostic messages contain archetypes (τ_0_0). Filed TF-1073 to track this improvement.

while (isa<AnyFunctionType>(canActualResultType)) {
canActualResultType =
cast<AnyFunctionType>(canActualResultType).getResult();
}
CanType actualFuncEltType =
cast<TupleType>(canActualResultType).getElementType(1);

// Compute expected differential/pullback type.
auto funcEltType = funcResultElt.getType();
Type expectedFuncEltType;
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
auto diffParams = map<SmallVector<AnyFunctionType::Param, 4>>(
Expand All @@ -3595,7 +3616,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext();

// Check if differential/pullback type matches expected type.
if (!funcEltType->isEqual(expectedFuncEltType)) {
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
// Emit differential/pullback type mismatch error on attribute.
diagnoseAndRemoveAttr(attr, diag::derivative_attr_result_func_type_mismatch,
funcResultElt.getName(), originalAFD->getFullName());
Expand Down
38 changes: 34 additions & 4 deletions test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func req1<T>(_ x: T) -> T {
return x
}
@derivative(of: req1)
func vjpReq1<T: Differentiable>(_ x: T) -> (
func vjpExtraConformanceConstraint<T: Differentiable>(_ x: T) -> (
value: T, pullback: (T.TangentVector) -> T.TangentVector
) {
return (x, { $0 })
Expand All @@ -286,12 +286,42 @@ func req2<T, U>(_ x: T, _ y: U) -> T {
return x
}
@derivative(of: req2)
func vjpReq2<T: Differentiable, U: Differentiable>(_ x: T, _ y: U)
-> (value: T, pullback: (T) -> (T, U))
where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
func vjpExtraConformanceConstraints<T: Differentiable, U: Differentiable>( _ x: T, _ y: U) -> (
value: T, pullback: (T) -> (T, U)
) where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
return (x, { ($0, .zero) })
}

// Test `@derivative` declaration with extra same-type requirements.
func req3<T>(_ x: T) -> T {
return x
}
@derivative(of: req3)
func vjpSameTypeRequirementsGenericParametersAllConcrete<T>(_ x: T) -> (
value: T, pullback: (T.TangentVector) -> T.TangentVector
) where T: Differentiable, T.TangentVector == Float {
return (x, { $0 })
}

struct Wrapper<T: Equatable>: Equatable {
var x: T
init(_ x: T) { self.x = x }
}
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
static var zero: Self { .init(.zero) }
static func + (lhs: Self, rhs: Self) -> Self { .init(lhs.x + rhs.x) }
static func - (lhs: Self, rhs: Self) -> Self { .init(lhs.x - rhs.x) }
}
extension Wrapper: Differentiable where T: Differentiable, T == T.TangentVector {
typealias TangentVector = Wrapper<T.TangentVector>
}
extension Wrapper where T: Differentiable, T == T.TangentVector {
@derivative(of: init(_:))
static func vjpInit(_ x: T) -> (value: Self, pullback: (Wrapper<T>.TangentVector) -> (T)) {
fatalError()
}
}

// Test class methods.

class Super {
Expand Down