Skip to content

Commit d10f51a

Browse files
authored
[AutoDiff] Fix @differentiable(linear) type-checking. (#28927)
For all non-`@noDerivative` parameter and result types, `@differentiable(linear)` function types should require and imply `T: Differentiable`, `T == T.TangentVector` requirements instead of `T: Differentiable & AdditiveArithmetic`. Emit `@noDerivative` fixit only when there is at least one valid differentiability/linearity parameter. Otherwise, adding `@noDerivative` produces an ill-formed `@differentiable` function type. Update tests.
1 parent 779e35f commit d10f51a

File tree

5 files changed

+168
-62
lines changed

5 files changed

+168
-62
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4094,13 +4094,18 @@ ERROR(attr_only_on_parameters_of_differentiable,none,
40944094
"'%0' may only be used on parameters of '@differentiable' function "
40954095
"types", (StringRef))
40964096
// SWIFT_ENABLE_TENSORFLOW
4097-
ERROR(autodiff_attr_argument_not_differentiable,none,
4098-
"argument is not differentiable, but the enclosing function type is "
4099-
"marked '@differentiable'; did you want to add '@noDerivative' to this "
4100-
"argument?", ())
4101-
ERROR(autodiff_attr_result_not_differentiable,none,
4102-
"result is not differentiable, but the function type is marked "
4103-
"'@differentiable'", ())
4097+
ERROR(differentiable_function_type_invalid_parameter,none,
4098+
"parameter type '%0' does not conform to 'Differentiable'"
4099+
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4100+
"function type is '@differentiable%select{|(linear)}1'"
4101+
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
4102+
(StringRef, /*tangentVectorEqualsSelf*/ bool,
4103+
/*hasValidDifferentiabilityParameter*/ bool))
4104+
ERROR(differentiable_function_type_invalid_result,none,
4105+
"result type '%0' does not conform to 'Differentiable'"
4106+
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
4107+
"function type is '@differentiable%select{|(linear)}1'",
4108+
(StringRef, bool))
41044109
ERROR(attr_differentiable_no_vjp_or_jvp_when_linear,none,
41054110
"cannot specify 'vjp:' or 'jvp:' for linear functions; use "
41064111
"'transpose:' instead", ())

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5073,31 +5073,47 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
50735073
}
50745074

50755075
// SWIFT_ENABLE_TENSORFLOW
5076-
// Infer `Differentiable` or `Differentiable & AdditiveArithmetic` generic
5077-
// constraints from `@differentiable` or `@differentiable(linear)`.
5076+
// Infer requirements from `@differentiable` or `@differentiable(linear)`
5077+
// function types.
5078+
// For all non-`@noDerivative` parameter and result types:
5079+
// - `@differentiable`: add `T: Differentiable` requirement.
5080+
// - `@differentiable(linear)`: add
5081+
// `T: Differentiable`, `T == T.TangentVector` requirements.
50785082
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
50795083
if (fnTy->isDifferentiable()) {
5080-
auto addConstraint = [&](Type typeToConstrain, ProtocolDecl *protocol) {
5081-
Requirement req(RequirementKind::Conformance, typeToConstrain,
5084+
auto addConformanceConstraint = [&](Type type, ProtocolDecl *protocol) {
5085+
Requirement req(RequirementKind::Conformance, type,
50825086
protocol->getDeclaredType());
50835087
Builder.addRequirement(req, source, nullptr);
50845088
};
5085-
auto constrainParametersAndResult = [&](ProtocolDecl *protocol) {
5089+
auto addSameTypeConstraint = [&](Type firstType,
5090+
AssociatedTypeDecl *assocType) {
5091+
auto *protocol = assocType->getProtocol();
5092+
auto conf = Builder.lookupConformance(CanType(), firstType, protocol);
5093+
auto secondType = conf.getAssociatedType(
5094+
firstType, assocType->getDeclaredInterfaceType());
5095+
Requirement req(RequirementKind::SameType, firstType, secondType);
5096+
Builder.addRequirement(req, source, nullptr);
5097+
};
5098+
auto &ctx = Builder.getASTContext();
5099+
auto *differentiableProtocol =
5100+
ctx.getProtocol(KnownProtocolKind::Differentiable);
5101+
auto *tangentVectorAssocType =
5102+
differentiableProtocol->getAssociatedType(ctx.Id_TangentVector);
5103+
auto addRequirements = [&](Type type, bool isLinear) {
5104+
addConformanceConstraint(type, differentiableProtocol);
5105+
if (isLinear)
5106+
addSameTypeConstraint(type, tangentVectorAssocType);
5107+
};
5108+
auto constrainParametersAndResult = [&](bool isLinear) {
50865109
for (auto &param : fnTy->getParams())
50875110
if (!param.isNoDerivative())
5088-
addConstraint(param.getPlainType(), protocol);
5089-
addConstraint(fnTy->getResult(), protocol);
5111+
addRequirements(param.getPlainType(), isLinear);
5112+
addRequirements(fnTy->getResult(), isLinear);
50905113
};
5091-
// Add `Differentiable` constraints.
5092-
constrainParametersAndResult(
5093-
Builder.getASTContext()
5094-
.getProtocol(KnownProtocolKind::Differentiable));
5095-
// Add `AdditiveArithmetic` constraints if the function is linear.
5096-
if (fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
5097-
constrainParametersAndResult(
5098-
Builder.getASTContext()
5099-
.getProtocol(KnownProtocolKind::AdditiveArithmetic));
5100-
}
5114+
// Add requirements.
5115+
constrainParametersAndResult(fnTy->getDifferentiabilityKind() ==
5116+
DifferentiabilityKind::Linear);
51015117
}
51025118
}
51035119

lib/Sema/TypeCheckType.cpp

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,7 +1863,10 @@ namespace {
18631863
TypeResolutionOptions options);
18641864

18651865
// SWIFT_ENABLE_TENSORFLOW
1866-
bool isDifferentiableType(Type ty);
1866+
/// Returns true if the given type conforms to `Differentiable` in the
1867+
/// module of `DC`. If `tangentVectorEqualsSelf` is true, returns true iff
1868+
/// the given type additionally satisfies `Self == Self.TangentVector`.
1869+
bool isDifferentiable(Type type, bool tangentVectorEqualsSelf = false);
18671870
};
18681871
} // end anonymous namespace
18691872

@@ -2578,23 +2581,48 @@ bool TypeResolver::resolveASTFunctionTypeParams(
25782581
// SWIFT_ENABLE_TENSORFLOW END
25792582
}
25802583

2581-
// SWIFT_ENABLE_TENSORFLOW
2582-
if (diffKind != DifferentiabilityKind::NonDifferentiable &&
2583-
resolution.getStage() != TypeResolutionStage::Structural) {
2584-
if (!noDerivative && !isDifferentiableType(ty)) {
2585-
diagnose(eltTypeRepr->getLoc(),
2586-
diag::autodiff_attr_argument_not_differentiable)
2587-
.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
2588-
}
2589-
}
2590-
// SWIFT_ENABLE_TENSORFLOW END
2591-
25922584
auto paramFlags = ParameterTypeFlags::fromParameterType(
25932585
ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership,
25942586
noDerivative);
25952587
elements.emplace_back(ty, Identifier(), paramFlags);
25962588
}
25972589

2590+
// SWIFT_ENABLE_TENSORFLOW
2591+
// All non-`@noDerivative` parameters of `@differentiable` and
2592+
// `@differentiable(linear)` function types must be differentiable.
2593+
if (diffKind != DifferentiabilityKind::NonDifferentiable &&
2594+
resolution.getStage() != TypeResolutionStage::Structural) {
2595+
bool isLinear = diffKind == DifferentiabilityKind::Linear;
2596+
// Emit `@noDerivative` fixit only if there is at least one valid
2597+
// differentiability/linearity parameter. Otherwise, adding `@noDerivative`
2598+
// produces an ill-formed function type.
2599+
auto hasValidDifferentiabilityParam =
2600+
llvm::find_if(elements, [&](AnyFunctionType::Param param) {
2601+
if (param.isNoDerivative())
2602+
return false;
2603+
return isDifferentiable(param.getPlainType(),
2604+
/*tangentVectorEqualsSelf*/ isLinear);
2605+
}) != elements.end();
2606+
// });
2607+
for (unsigned i = 0, end = inputRepr->getNumElements(); i != end; ++i) {
2608+
auto *eltTypeRepr = inputRepr->getElementType(i);
2609+
auto param = elements[i];
2610+
if (param.isNoDerivative())
2611+
continue;
2612+
auto paramType = param.getPlainType();
2613+
if (isDifferentiable(paramType, /*tangentVectorEqualsSelf*/ isLinear))
2614+
continue;
2615+
auto paramTypeString = paramType->getString();
2616+
auto diagnostic =
2617+
diagnose(eltTypeRepr->getLoc(),
2618+
diag::differentiable_function_type_invalid_parameter,
2619+
paramTypeString, isLinear, hasValidDifferentiabilityParam);
2620+
if (hasValidDifferentiabilityParam)
2621+
diagnostic.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
2622+
}
2623+
}
2624+
// SWIFT_ENABLE_TENSORFLOW END
2625+
25982626
return false;
25992627
}
26002628

@@ -2717,30 +2745,38 @@ Type TypeResolver::resolveASTFunctionType(
27172745
}
27182746

27192747
// SWIFT_ENABLE_TENSORFLOW
2720-
// If the function is marked as `@differentiable`, the result must be a
2721-
// differentiable type.
2748+
// `@differentiable` and `@differentiable(linear)` function types must return
2749+
// a differentiable type.
27222750
if (extInfo.isDifferentiable() &&
27232751
resolution.getStage() != TypeResolutionStage::Structural) {
2724-
if (!isDifferentiableType(outputTy)) {
2752+
bool isLinear = diffKind == DifferentiabilityKind::Linear;
2753+
if (!isDifferentiable(outputTy, /*tangentVectorEqualsSelf*/ isLinear)) {
27252754
diagnose(repr->getResultTypeRepr()->getLoc(),
2726-
diag::autodiff_attr_result_not_differentiable)
2755+
diag::differentiable_function_type_invalid_result,
2756+
outputTy->getString(), isLinear)
27272757
.highlight(repr->getResultTypeRepr()->getSourceRange());
27282758
}
27292759
}
2760+
// SWIFT_ENABLE_TENSORFLOW END
27302761

27312762
return fnTy;
27322763
}
27332764

27342765
// SWIFT_ENABLE_TENSORFLOW
2735-
bool TypeResolver::isDifferentiableType(Type ty) {
2736-
if (resolution.getStage() != TypeResolutionStage::Contextual) {
2737-
ty = DC->mapTypeIntoContext(ty);
2738-
}
2739-
return ty
2740-
->getAutoDiffAssociatedTangentSpace(
2741-
LookUpConformanceInModule(DC->getParentModule()))
2742-
.hasValue();
2766+
bool TypeResolver::isDifferentiable(Type type, bool tangentVectorEqualsSelf) {
2767+
if (resolution.getStage() != TypeResolutionStage::Contextual)
2768+
type = DC->mapTypeIntoContext(type);
2769+
auto tanSpace = type->getAutoDiffAssociatedTangentSpace(
2770+
LookUpConformanceInModule(DC->getParentModule()));
2771+
if (!tanSpace)
2772+
return false;
2773+
// If no `Self == Self.TangentVector` requirement, return true.
2774+
if (!tangentVectorEqualsSelf)
2775+
return true;
2776+
// Otherwise, return true if `Self == Self.TangentVector`.
2777+
return type->getCanonicalType() == tanSpace->getCanonicalType();
27432778
}
2779+
// SWIFT_ENABLE_TENSORFLOW END
27442780

27452781
Type TypeResolver::resolveSILBoxType(SILBoxTypeRepr *repr,
27462782
TypeResolutionOptions options) {

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ extension TF_521 where T: Differentiable, T == T.TangentVector {
776776
return (TF_521(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
777777
}
778778
}
779-
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
779+
// expected-error @+1 {{result type 'TF_521<Float>' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
780780
let _: @differentiable(Float, Float) -> TF_521<Float> = { r, i in
781781
TF_521(real: r, imaginary: i)
782782
}

test/AutoDiff/differentiable_func_type_type_checking.swift

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,28 @@ let _: @differentiable (Float) throws -> Float
88
//===----------------------------------------------------------------------===//
99

1010
struct NonDiffType { var x: Int }
11+
1112
// FIXME: Properly type-check parameters and the result's differentiability
12-
// expected-error @+1 {{argument is not differentiable, but the enclosing function type is marked '@differentiable'}} {{25-25=@noDerivative }}
13+
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
1314
let _: @differentiable (NonDiffType) -> Float
14-
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
15+
16+
// Emit `@noDerivative` fixit iff there is at least one valid differentiability parameter.
17+
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'; did you want to add '@noDerivative' to this parameter?}} {{32-32=@noDerivative }}
18+
let _: @differentiable (Float, NonDiffType) -> Float
19+
20+
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
21+
let _: @differentiable(linear) (Float) -> NonDiffType
22+
23+
// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
24+
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?}} {{40-40=@noDerivative }}
25+
let _: @differentiable(linear) (Float, NonDiffType) -> Float
26+
27+
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
1528
let _: @differentiable (Float) -> NonDiffType
1629

30+
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
31+
let _: @differentiable(linear) (Float) -> NonDiffType
32+
1733
let _: @differentiable(linear) (Float) -> Float
1834

1935
//===----------------------------------------------------------------------===//
@@ -80,10 +96,10 @@ func foo<T: Differentiable, U: Differentiable>(x: T) -> U {
8096

8197
func test1<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @differentiable (U) -> Float) {}
8298
func test2<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Float) {}
83-
// expected-error @+2 {{result is not differentiable, but the function type is marked '@differentiable'}}
84-
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
99+
// expected-error @+2 {{result type 'Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
100+
// expected-error @+1 {{result type '@differentiable (U) -> Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
85101
func test3<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @differentiable (U) -> Int) {}
86-
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
102+
// expected-error @+1 {{result type '(U) -> Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
87103
func test4<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Int) {}
88104

89105
let diffFunc: @differentiable (Float) -> Float
@@ -107,23 +123,56 @@ struct Vector<T> {
107123
var x, y: T
108124
}
109125
extension Vector: Equatable where T: Equatable {}
110-
extension Vector: Differentiable where T: Differentiable {}
111126
extension Vector: AdditiveArithmetic where T: AdditiveArithmetic {}
127+
extension Vector: Differentiable where T: Differentiable {}
112128

113-
// expected-note @+1 {{where 'T' = 'Int'}}
129+
// expected-note @+1 2 {{where 'T' = 'Int'}}
114130
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}
115131

116-
// expected-note @+1 {{where 'T' = 'Int'}}
132+
// expected-error @+5 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
133+
// expected-error @+4 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
134+
// expected-error @+3 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
135+
// expected-error @+2 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
136+
// expected-note @+1 2 {{where 'T' = 'Int'}}
117137
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Vector<T>) -> Vector<U>) {}
118138

119-
func nondiffVectorFunc(x: Vector<Int>) -> Vector<Int> {}
139+
func nondiff(x: Vector<Int>) -> Vector<Int> {}
120140
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
121-
inferredConformancesGeneric(nondiffVectorFunc)
141+
inferredConformancesGeneric(nondiff)
122142
// expected-error @+1 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable}}
123-
inferredConformancesGenericLinear(nondiffVectorFunc)
143+
inferredConformancesGenericLinear(nondiff)
124144

125-
func diffVectorFunc(x: Vector<Float>) -> Vector<Float> {}
126-
inferredConformancesGeneric(diffVectorFunc) // okay!
145+
func diff(x: Vector<Float>) -> Vector<Float> {}
146+
inferredConformancesGeneric(diff) // okay!
127147

128148
func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}
149+
// expected-error @+4 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
150+
// expected-error @+3 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
151+
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
152+
// expected-error @+1 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
129153
func inferredConformancesGenericResultLinear<T, U>() -> @differentiable(linear) (Vector<T>) -> Vector<U> {}
154+
155+
struct Linear<T> {
156+
var x, y: T
157+
}
158+
extension Linear: Equatable where T: Equatable {}
159+
extension Linear: AdditiveArithmetic where T: AdditiveArithmetic {}
160+
extension Linear: Differentiable where T: Differentiable, T == T.TangentVector {
161+
typealias TangentVector = Self
162+
}
163+
164+
func inferredConformancesGeneric<T, U>(_: @differentiable (Linear<T>) -> Linear<U>) {}
165+
166+
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Linear<T>) -> Linear<U>) {}
167+
168+
func nondiff(x: Linear<Int>) -> Linear<Int> {}
169+
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
170+
inferredConformancesGeneric(nondiff)
171+
// expected-error @+1 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable}}
172+
inferredConformancesGenericLinear(nondiff)
173+
174+
func diff(x: Linear<Float>) -> Linear<Float> {}
175+
inferredConformancesGeneric(diff) // okay!
176+
177+
func inferredConformancesGenericResult<T, U>() -> @differentiable (Linear<T>) -> Linear<U> {}
178+
func inferredConformancesGenericResultLinear<T, U>() -> @differentiable(linear) (Linear<T>) -> Linear<U> {}

0 commit comments

Comments
 (0)