Skip to content

Commit ddcb1d5

Browse files
authored
Merge pull request #28853 from dan-zheng/derivative-attr-type-checking
2 parents b64c620 + 1d66571 commit ddcb1d5

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3292,25 +3292,34 @@ static bool checkFunctionSignature(
32923292
return false;
32933293
}
32943294

3295+
// Map type into the required function type's generic signature, if it exists.
3296+
// This is significant when the required generic signature has same-type
3297+
// requirements while the candidate generic signature does not.
3298+
auto mapType = [&](Type type) {
3299+
if (!requiredGenSig)
3300+
return type->getCanonicalType();
3301+
return requiredGenSig->getCanonicalTypeInContext(type);
3302+
};
3303+
32953304
// Check that parameter types match, disregarding labels.
32963305
if (required->getNumParams() != candidateFnTy->getNumParams())
32973306
return false;
32983307
if (!std::equal(required->getParams().begin(), required->getParams().end(),
32993308
candidateFnTy->getParams().begin(),
3300-
[](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3301-
return x.getPlainType()->isEqual(y.getPlainType());
3309+
[&](AnyFunctionType::Param x, AnyFunctionType::Param y) {
3310+
return x.getPlainType()->isEqual(mapType(y.getPlainType()));
33023311
}))
33033312
return false;
33043313

33053314
// If required result type is not a function type, check that result types
33063315
// match exactly.
33073316
auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult());
3317+
auto candidateResultTy = mapType(candidateFnTy.getResult());
33083318
if (!requiredResultFnTy) {
33093319
auto requiredResultTupleTy = dyn_cast<TupleType>(required.getResult());
3310-
auto candidateResultTupleTy =
3311-
dyn_cast<TupleType>(candidateFnTy.getResult());
3320+
auto candidateResultTupleTy = dyn_cast<TupleType>(candidateResultTy);
33123321
if (!requiredResultTupleTy || !candidateResultTupleTy)
3313-
return required.getResult()->isEqual(candidateFnTy.getResult());
3322+
return required.getResult()->isEqual(candidateResultTy);
33143323
// If result types are tuple types, check that element types match,
33153324
// ignoring labels.
33163325
if (requiredResultTupleTy->getNumElements() !=
@@ -3323,7 +3332,7 @@ static bool checkFunctionSignature(
33233332
}
33243333

33253334
// Required result type is a function. Recurse.
3326-
return checkFunctionSignature(requiredResultFnTy, candidateFnTy.getResult());
3335+
return checkFunctionSignature(requiredResultFnTy, candidateResultTy);
33273336
};
33283337

33293338
// Returns an `AnyFunctionType` with the same `ExtInfo` as `fnType`, but with
@@ -3607,8 +3616,20 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36073616
auto resultTanType = valueResultConf.getTypeWitnessByName(
36083617
valueResultType, Ctx.Id_TangentVector);
36093618

3619+
// Compute the actual differential/pullback type that we use for comparison
3620+
// with the expected type. We must canonicalize the derivative interface type
3621+
// before extracting the differential/pullback type from it, so that the
3622+
// derivative interface type generic signature is available for simplifying
3623+
// types.
3624+
CanType canActualResultType = derivativeInterfaceType->getCanonicalType();
3625+
while (isa<AnyFunctionType>(canActualResultType)) {
3626+
canActualResultType =
3627+
cast<AnyFunctionType>(canActualResultType).getResult();
3628+
}
3629+
CanType actualFuncEltType =
3630+
cast<TupleType>(canActualResultType).getElementType(1);
3631+
36103632
// Compute expected differential/pullback type.
3611-
auto funcEltType = funcResultElt.getType();
36123633
Type expectedFuncEltType;
36133634
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
36143635
auto diffParams = map<SmallVector<AnyFunctionType::Param, 4>>(
@@ -3624,7 +3645,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) {
36243645
expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext();
36253646

36263647
// Check if differential/pullback type matches expected type.
3627-
if (!funcEltType->isEqual(expectedFuncEltType)) {
3648+
if (!actualFuncEltType->isEqual(expectedFuncEltType)) {
36283649
// Emit differential/pullback type mismatch error on attribute.
36293650
diagnoseAndRemoveAttr(attr, diag::derivative_attr_result_func_type_mismatch,
36303651
funcResultElt.getName(), originalAFD->getFullName());

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ func req1<T>(_ x: T) -> T {
276276
return x
277277
}
278278
@derivative(of: req1)
279-
func vjpReq1<T: Differentiable>(_ x: T) -> (
279+
func vjpExtraConformanceConstraint<T: Differentiable>(_ x: T) -> (
280280
value: T, pullback: (T.TangentVector) -> T.TangentVector
281281
) {
282282
return (x, { $0 })
@@ -286,12 +286,42 @@ func req2<T, U>(_ x: T, _ y: U) -> T {
286286
return x
287287
}
288288
@derivative(of: req2)
289-
func vjpReq2<T: Differentiable, U: Differentiable>(_ x: T, _ y: U)
290-
-> (value: T, pullback: (T) -> (T, U))
291-
where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
289+
func vjpExtraConformanceConstraints<T: Differentiable, U: Differentiable>( _ x: T, _ y: U) -> (
290+
value: T, pullback: (T) -> (T, U)
291+
) where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible {
292292
return (x, { ($0, .zero) })
293293
}
294294

295+
// Test `@derivative` declaration with extra same-type requirements.
296+
func req3<T>(_ x: T) -> T {
297+
return x
298+
}
299+
@derivative(of: req3)
300+
func vjpSameTypeRequirementsGenericParametersAllConcrete<T>(_ x: T) -> (
301+
value: T, pullback: (T.TangentVector) -> T.TangentVector
302+
) where T: Differentiable, T.TangentVector == Float {
303+
return (x, { $0 })
304+
}
305+
306+
struct Wrapper<T: Equatable>: Equatable {
307+
var x: T
308+
init(_ x: T) { self.x = x }
309+
}
310+
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
311+
static var zero: Self { .init(.zero) }
312+
static func + (lhs: Self, rhs: Self) -> Self { .init(lhs.x + rhs.x) }
313+
static func - (lhs: Self, rhs: Self) -> Self { .init(lhs.x - rhs.x) }
314+
}
315+
extension Wrapper: Differentiable where T: Differentiable, T == T.TangentVector {
316+
typealias TangentVector = Wrapper<T.TangentVector>
317+
}
318+
extension Wrapper where T: Differentiable, T == T.TangentVector {
319+
@derivative(of: init(_:))
320+
static func vjpInit(_ x: T) -> (value: Self, pullback: (Wrapper<T>.TangentVector) -> (T)) {
321+
fatalError()
322+
}
323+
}
324+
295325
// Test class methods.
296326

297327
class Super {

0 commit comments

Comments
 (0)