Skip to content

Commit 1d66571

Browse files
committed
[AutoDiff upstream] @derivative attribute type-checking fixes.
Upstream `@derivative` attribute type-checking fixes regarding derivative generic signatures with all concrete generic parameters. Cherry-picked from: - swiftlang#28762 - swiftlang#28772
1 parent bac5a64 commit 1d66571

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3263,25 +3263,34 @@ static bool checkFunctionSignature(
32633263
return false;
32643264
}
32653265

3266+
// Map type into the required function type's generic signature, if it exists.
3267+
// This is significant when the required generic signature has same-type
3268+
// requirements while the candidate generic signature does not.
3269+
auto mapType = [&](Type type) {
3270+
if (!requiredGenSig)
3271+
return type->getCanonicalType();
3272+
return requiredGenSig->getCanonicalTypeInContext(type);
3273+
};
3274+
32663275
// Check that parameter types match, disregarding labels.
32673276
if (required->getNumParams() != candidateFnTy->getNumParams())
32683277
return false;
32693278
if (!std::equal(required->getParams().begin(), required->getParams().end(),
32703279
candidateFnTy->getParams().begin(),
3271-
[](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3272-
return x.getPlainType()->isEqual(y.getPlainType());
3280+
[&](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3281+
return x.getPlainType()->isEqual(mapType(y.getPlainType()));
32733282
}))
32743283
return false;
32753284

32763285
// If required result type is not a function type, check that result types
32773286
// match exactly.
32783287
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult());
3288+
auto candidateResultTy = mapType(candidateFnTy.getResult());
32793289
if (!requiredResultFnTy) {
32803290
auto requiredResultTupleTy = dyn_cast<TupleType>(required.getResult());
3281-
auto candidateResultTupleTy =
3282-
dyn_cast<TupleType>(candidateFnTy.getResult());
3291+
auto candidateResultTupleTy = dyn_cast<TupleType>(candidateResultTy);
32833292
if (!requiredResultTupleTy || !candidateResultTupleTy)
3284-
return required.getResult()->isEqual(candidateFnTy.getResult());
3293+
return required.getResult()->isEqual(candidateResultTy);
32853294
// If result types are tuple types, check that element types match,
32863295
// ignoring labels.
32873296
if (requiredResultTupleTy->getNumElements() !=
@@ -3294,7 +3303,7 @@ static bool checkFunctionSignature(
32943303
}
32953304

32963305
// Required result type is a function. Recurse.
3297-
return checkFunctionSignature(requiredResultFnTy, candidateFnTy.getResult());
3306+
return checkFunctionSignature(requiredResultFnTy, candidateResultTy);
32983307
};
32993308

33003309
// Returns an `AnyFunctionType` with the same `ExtInfo` as `fnType`, but with
@@ -3578,8 +3587,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
35783587
auto resultTanType = valueResultConf.getTypeWitnessByName(
35793588
valueResultType, Ctx.Id_TangentVector);
35803589

3590+
// Compute the actual differential/pullback type that we use for comparison
3591+
// with the expected type. We must canonicalize the derivative interface type
3592+
// before extracting the differential/pullback type from it, so that the
3593+
// derivative interface type generic signature is available for simplifying
3594+
// types.
3595+
CanType canActualResultType = derivativeInterfaceType->getCanonicalType();
3596+
while (isa<AnyFunctionType>(canActualResultType)) {
3597+
canActualResultType =
3598+
cast<AnyFunctionType>(canActualResultType).getResult();
3599+
}
3600+
CanType actualFuncEltType =
3601+
cast<TupleType>(canActualResultType).getElementType(1);
3602+
35813603
// Compute expected differential/pullback type.
3582-
auto funcEltType = funcResultElt.getType();
35833604
Type expectedFuncEltType;
35843605
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
35853606
auto diffParams = map<SmallVector<AnyFunctionType::Param, 4>>(
@@ -3595,7 +3616,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
35953616
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext();
35963617

35973618
// Check if differential/pullback type matches expected type.
3598-
if (!funcEltType->isEqual(expectedFuncEltType)) {
3619+
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
35993620
// Emit differential/pullback type mismatch error on attribute.
36003621
diagnoseAndRemoveAttr(attr, diag::derivative_attr_result_func_type_mismatch,
36013622
funcResultElt.getName(), originalAFD->getFullName());

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ func req1<T>(_ x: T) -> T {
276276
return x
277277
}
278278
@derivative(of: req1)
279-
func vjpReq1<T: Differentiable>(_ x: T) -> (
279+
func vjpExtraConformanceConstraint<T: Differentiable>(_ x: T) -> (
280280
value: T, pullback: (T.TangentVector) -> T.TangentVector
281281
) {
282282
return (x, { $0 })
@@ -286,12 +286,42 @@ func req2<T, U>(_ x: T, _ y: U) -> T {
286286
return x
287287
}
288288
@derivative(of: req2)
289-
func vjpReq2<T: Differentiable, U: Differentiable>(_ x: T, _ y: U)
290-
-> (value: T, pullback: (T) -> (T, U))
291-
where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
289+
func vjpExtraConformanceConstraints<T: Differentiable, U: Differentiable>( _ x: T, _ y: U) -> (
290+
value: T, pullback: (T) -> (T, U)
291+
) where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
292292
return (x, { ($0, .zero) })
293293
}
294294

295+
// Test `@derivative` declaration with extra same-type requirements.
296+
func req3<T>(_ x: T) -> T {
297+
return x
298+
}
299+
@derivative(of: req3)
300+
func vjpSameTypeRequirementsGenericParametersAllConcrete<T>(_ x: T) -> (
301+
value: T, pullback: (T.TangentVector) -> T.TangentVector
302+
) where T: Differentiable, T.TangentVector == Float {
303+
return (x, { $0 })
304+
}
305+
306+
struct Wrapper<T: Equatable>: Equatable {
307+
var x: T
308+
init(_ x: T) { self.x = x }
309+
}
310+
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
311+
static var zero: Self { .init(.zero) }
312+
static func + (lhs: Self, rhs: Self) -> Self { .init(lhs.x + rhs.x) }
313+
static func - (lhs: Self, rhs: Self) -> Self { .init(lhs.x - rhs.x) }
314+
}
315+
extension Wrapper: Differentiable where T: Differentiable, T == T.TangentVector {
316+
typealias TangentVector = Wrapper<T.TangentVector>
317+
}
318+
extension Wrapper where T: Differentiable, T == T.TangentVector {
319+
@derivative(of: init(_:))
320+
static func vjpInit(_ x: T) -> (value: Self, pullback: (Wrapper<T>.TangentVector) -> (T)) {
321+
fatalError()
322+
}
323+
}
324+
295325
// Test class methods.
296326

297327
class Super {

0 commit comments

Comments
 (0)