Skip to content

Commit 0e8a67f

Browse files
authored
[AutoDiff] Imply 'AdditiveArithmetic' generic constraints from '@differentiable(linear)' parameters or results. (#28472)
Update `GenericSignatureBuilder` to match the behavior specified in the manifesto: https://github.com/apple/swift/blob/master/docs/DifferentiableProgramming.md#implied-generic-constraints Previously, #24896 added support for implied `Differentiable` constraints from `@differentiable` parameters or results.
1 parent 9853570 commit 0e8a67f

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5069,19 +5069,31 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
50695069
}
50705070

50715071
// SWIFT_ENABLE_TENSORFLOW
5072+
// Infer `Differentiable` or `Differentiable & AdditiveArithmetic` generic
5073+
// constraints from `@differentiable` or `@differentiable(linear)`.
50725074
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
5073-
if (fnTy->getExtInfo().isDifferentiable()) {
5074-
auto *diffableProto = Builder.getASTContext()
5075-
.getProtocol(KnownProtocolKind::Differentiable);
5076-
auto constrainToDifferentiable = [&](Type typeToConstrain) {
5075+
if (fnTy->isDifferentiable()) {
5076+
auto addConstraint = [&](Type typeToConstrain, ProtocolDecl *protocol) {
50775077
Requirement req(RequirementKind::Conformance, typeToConstrain,
5078-
diffableProto->getDeclaredType());
5078+
protocol->getDeclaredType());
50795079
Builder.addRequirement(req, source, nullptr);
50805080
};
5081-
for (auto &param : fnTy->getParams())
5082-
if (!param.isNonDifferentiable())
5083-
constrainToDifferentiable(param.getPlainType());
5084-
constrainToDifferentiable(fnTy->getResult());
5081+
auto constrainParametersAndResult = [&](ProtocolDecl *protocol) {
5082+
for (auto &param : fnTy->getParams())
5083+
if (!param.isNonDifferentiable())
5084+
addConstraint(param.getPlainType(), protocol);
5085+
addConstraint(fnTy->getResult(), protocol);
5086+
};
5087+
// Add `Differentiable` constraints.
5088+
constrainParametersAndResult(
5089+
Builder.getASTContext()
5090+
.getProtocol(KnownProtocolKind::Differentiable));
5091+
// Add `AdditiveArithmetic` constraints if the function is linear.
5092+
if (fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
5093+
constrainParametersAndResult(
5094+
Builder.getASTContext()
5095+
.getProtocol(KnownProtocolKind::AdditiveArithmetic));
5096+
}
50855097
}
50865098
}
50875099

test/AutoDiff/differentiable_func_type_type_checking.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,28 +84,43 @@ func test3<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @diff
8484
func test4<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Int) {}
8585

8686
let diffFunc: @differentiable (Float) -> Float
87+
let linearFunc: @differentiable(linear) (Float) -> Float
8788
func inferredConformances<T, U>(_: @differentiable (T) -> U) {}
89+
func inferredConformancesLinear<T, U>(_: @differentiable(linear) (T) -> U) {}
8890
inferredConformances(diffFunc)
91+
inferredConformancesLinear(linearFunc)
8992

9093
func inferredConformancesResult<T, U>() -> @differentiable (T) -> U {}
94+
func inferredConformancesResultLinear<T, U>() -> @differentiable(linear) (T) -> U {}
9195

9296
let diffFuncWithNondiff: @differentiable (Float, @nondiff Int) -> Float
97+
let linearFuncWithNondiff: @differentiable(linear) (Float, @nondiff Int) -> Float
9398
func inferredConformances<T, U, V>(_: @differentiable (T, @nondiff U) -> V) {}
99+
func inferredConformancesLinear<T, U, V>(_: @differentiable(linear) (T, @nondiff U) -> V) {}
94100
inferredConformances(diffFuncWithNondiff)
101+
inferredConformancesLinear(linearFuncWithNondiff)
95102

96103
struct Vector<T> {
97104
var x, y: T
98105
}
106+
extension Vector: Equatable where T: Equatable {}
99107
extension Vector: Differentiable where T: Differentiable {}
108+
extension Vector: AdditiveArithmetic where T: AdditiveArithmetic {}
100109

101110
// expected-note @+1 {{where 'T' = 'Int'}}
102111
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}
103112

113+
// expected-note @+1 {{where 'T' = 'Int'}}
114+
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Vector<T>) -> Vector<U>) {}
115+
104116
func nondiffVectorFunc(x: Vector<Int>) -> Vector<Int> {}
105117
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
106118
inferredConformancesGeneric(nondiffVectorFunc)
119+
// expected-error @+1 {{global function 'inferredConformancesGenericLinear' requires that 'Int' conform to 'Differentiable}}
120+
inferredConformancesGenericLinear(nondiffVectorFunc)
107121

108122
func diffVectorFunc(x: Vector<Float>) -> Vector<Float> {}
109123
inferredConformancesGeneric(diffVectorFunc) // okay!
110124

111125
func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}
126+
func inferredConformancesGenericResultLinear<T, U>() -> @differentiable(linear) (Vector<T>) -> Vector<U> {}

0 commit comments

Comments
 (0)