Skip to content

[AutoDiff] [Sema] Infer Differentiable conformance for @differentiable function parameters and results. #24896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions lib/AST/GenericSignatureBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5311,6 +5311,23 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
return Action::Continue;
}

// SWIFT_ENABLE_TENSORFLOW
if (auto *fnTy = ty->getAs<AnyFunctionType>()) {
if (fnTy->getExtInfo().isDifferentiable()) {
auto *diffableProto = Builder.getASTContext()
.getProtocol(KnownProtocolKind::Differentiable);
auto constrainToDifferentiable = [&](Type typeToConstrain) {
Requirement req(RequirementKind::Conformance, typeToConstrain,
diffableProto->getDeclaredType());
Builder.addRequirement(req, source, nullptr);
};
for (auto &param : fnTy->getParams())
if (!param.isNonDifferentiable())
constrainToDifferentiable(param.getPlainType());
constrainToDifferentiable(fnTy->getResult());
}
}

if (!ty->isSpecialized())
return Action::Continue;

Expand Down
76 changes: 26 additions & 50 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -213,29 +213,29 @@ public extension Differentiable {

public extension Differentiable {
@inlinable
func valueWithPullback<R : Differentiable>(
func valueWithPullback<R>(
in f: @differentiable (Self) -> R
) -> (value: R, pullback: (R.TangentVector) -> TangentVector) {
return Builtin.autodiffApply_vjp_arity1(f, self)
}

@inlinable
func pullback<R : Differentiable>(
func pullback<R>(
in f: @differentiable (Self) -> R
) -> (R.TangentVector) -> TangentVector {
return Builtin.autodiffApply_vjp_arity1(f, self).1
}

@inlinable
func gradient<R : Differentiable>(
func gradient<R>(
in f: @differentiable (Self) -> R
) -> TangentVector
where R : FloatingPoint, R.TangentVector == R {
return self.pullback(in: f)(R(1))
}

@inlinable
func valueWithGradient<R : Differentiable>(
func valueWithGradient<R>(
in f: @differentiable (Self) -> R
) -> (value: R, gradient: TangentVector)
where R : FloatingPoint, R.TangentVector == R {
Expand All @@ -244,30 +244,30 @@ public extension Differentiable {
}

@inlinable
func valueWithPullback<T : Differentiable, R : Differentiable>(
func valueWithPullback<T, R>(
at x: T, in f: @differentiable (Self, T) -> R
) -> (value: R,
pullback: (R.TangentVector) -> (TangentVector, T.TangentVector)) {
return Builtin.autodiffApply_vjp_arity2(f, self, x)
}

@inlinable
func pullback<T : Differentiable, R : Differentiable>(
func pullback<T, R>(
at x: T, in f: @differentiable (Self, T) -> R
) -> (R.TangentVector) -> (TangentVector, T.TangentVector) {
return Builtin.autodiffApply_vjp_arity2(f, self, x).1
}

@inlinable
func gradient<T : Differentiable, R : Differentiable>(
func gradient<T, R>(
at x: T, in f: @differentiable (Self, T) -> R
) -> (TangentVector, T.TangentVector)
where R : FloatingPoint, R.TangentVector == R {
return self.pullback(at: x, in: f)(R(1))
}

@inlinable
func valueWithGradient<T : Differentiable, R : Differentiable>(
func valueWithGradient<T, R>(
at x: T, in f: @differentiable (Self, T) -> R
) -> (value: R, gradient: (TangentVector, T.TangentVector))
where R : FloatingPoint, R.TangentVector == R {
Expand All @@ -285,17 +285,15 @@ public extension Differentiable {
@inlinable
public func valueWithPullback<T, R>(
at x: T, in f: @differentiable (T) -> R
) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
where T : Differentiable, R : Differentiable {
) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) {
return Builtin.autodiffApply_vjp(f, x)
}

@inlinable
public func valueWithPullback<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> R
) -> (value: R,
pullback: (R.TangentVector) -> (T.TangentVector, U.TangentVector))
where T : Differentiable, U : Differentiable, R : Differentiable {
pullback: (R.TangentVector) -> (T.TangentVector, U.TangentVector)) {
return Builtin.autodiffApply_vjp_arity2(f, x, y)
}

Expand All @@ -304,9 +302,7 @@ public func valueWithPullback<T, U, V, R>(
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
) -> (value: R,
pullback: (R.TangentVector)
-> (T.TangentVector, U.TangentVector, V.TangentVector))
where T : Differentiable, U : Differentiable, V : Differentiable,
R : Differentiable {
-> (T.TangentVector, U.TangentVector, V.TangentVector)) {
return Builtin.autodiffApply_vjp_arity3(f, x, y, z)
}

Expand All @@ -315,26 +311,22 @@ public func valueWithPullback<T, U, V, R>(
@inlinable
public func pullback<T, R>(
at x: T, in f: @differentiable (T) -> R
) -> (R.TangentVector) -> T.TangentVector
where T : Differentiable, R : Differentiable {
) -> (R.TangentVector) -> T.TangentVector {
return Builtin.autodiffApply_vjp(f, x).1
}

@inlinable
public func pullback<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> R
) -> (R.TangentVector) -> (T.TangentVector, U.TangentVector)
where T : Differentiable, U : Differentiable, R : Differentiable {
) -> (R.TangentVector) -> (T.TangentVector, U.TangentVector) {
return Builtin.autodiffApply_vjp_arity2(f, x, y).1
}

@inlinable
public func pullback<T, U, V, R>(
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
) -> (R.TangentVector)
-> (T.TangentVector, U.TangentVector, V.TangentVector)
where T : Differentiable, U : Differentiable, V : Differentiable,
R : Differentiable {
-> (T.TangentVector, U.TangentVector, V.TangentVector) {
return Builtin.autodiffApply_vjp_arity3(f, x, y, z).1
}

Expand All @@ -344,8 +336,7 @@ public func pullback<T, U, V, R>(
public func valueWithGradient<T, R>(
at x: T, in f: @differentiable (T) -> R
) -> (value: R, gradient: T.TangentVector)
where T : Differentiable, R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
let (y, pullback) = valueWithPullback(at: x, in: f)
return (y, pullback(R(1)))
}
Expand All @@ -354,8 +345,7 @@ public func valueWithGradient<T, R>(
public func valueWithGradient<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> R
) -> (value: R, gradient: (T.TangentVector, U.TangentVector))
where T : Differentiable, U : Differentiable,
R : FloatingPoint & Differentiable, R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
let (y, pullback) = valueWithPullback(at: x, y, in: f)
return (y, pullback(R(1)))
}
Expand All @@ -365,8 +355,7 @@ public func valueWithGradient<T, U, V, R>(
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
) -> (value: R,
gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
where T : Differentiable, U : Differentiable, V : Differentiable,
R : FloatingPoint & Differentiable, R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
return (y, pullback(R(1)))
}
Expand All @@ -377,18 +366,15 @@ public func valueWithGradient<T, U, V, R>(
public func valueWithGradient<T, R>(
of f: @escaping @differentiable (T) -> R
) -> (T) -> (value: R, gradient: T.TangentVector)
where T : Differentiable, R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return { x in valueWithGradient(at: x, in: f) }
}

@inlinable
public func valueWithGradient<T, U, R>(
of f: @escaping @differentiable (T, U) -> R
) -> (T, U) -> (value: R, gradient: (T.TangentVector, U.TangentVector))
where T : Differentiable, U : Differentiable,
R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return { x, y in valueWithGradient(at: x, y, in: f) }
}

Expand All @@ -398,9 +384,7 @@ public func valueWithGradient<T, U, V, R>(
) -> (T, U, V)
-> (value: R,
gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
where T : Differentiable, U : Differentiable, V : Differentiable,
R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return { x, y, z in valueWithGradient(at: x, y, z, in: f) }
}

Expand All @@ -410,26 +394,23 @@ public func valueWithGradient<T, U, V, R>(
public func gradient<T, R>(
at x: T, in f: @differentiable (T) -> R
) -> T.TangentVector
where T : Differentiable, R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return pullback(at: x, in: f)(R(1))
}

@inlinable
public func gradient<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> R
) -> (T.TangentVector, U.TangentVector)
where T : Differentiable, U : Differentiable,
R : FloatingPoint & Differentiable, R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return pullback(at: x, y, in: f)(R(1))
}

@inlinable
public func gradient<T, U, V, R>(
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
) -> (T.TangentVector, U.TangentVector, V.TangentVector)
where T : Differentiable, U : Differentiable, V : Differentiable,
R : FloatingPoint & Differentiable, R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return pullback(at: x, y, z, in: f)(R(1))
}

Expand All @@ -439,28 +420,23 @@ public func gradient<T, U, V, R>(
public func gradient<T, R>(
of f: @escaping @differentiable (T) -> R
) -> (T) -> T.TangentVector
where T : Differentiable, R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return { x in gradient(at: x, in: f) }
}

@inlinable
public func gradient<T, U, R>(
of f: @escaping @differentiable (T, U) -> R
) -> (T, U) -> (T.TangentVector, U.TangentVector)
where T : Differentiable, U : Differentiable,
R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return { x, y in gradient(at: x, y, in: f) }
}

@inlinable
public func gradient<T, U, V, R>(
of f: @escaping @differentiable (T, U, V) -> R
) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector)
where T : Differentiable, U : Differentiable, V : Differentiable,
R : FloatingPoint & Differentiable,
R.TangentVector == R {
where R : FloatingPoint, R.TangentVector == R {
return { x, y, z in gradient(at: x, y, z, in: f) }
}

Expand Down
28 changes: 28 additions & 0 deletions test/AutoDiff/differentiable_func_type_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,31 @@ func test2<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -
func test3<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @differentiable (U) -> Int) {}
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
func test4<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Int) {}

let diffFunc: @differentiable (Float) -> Float
func inferredConformances<T, U>(_: @differentiable (T) -> U) {}
inferredConformances(diffFunc)

func inferredConformancesResult<T, U>() -> @differentiable (T) -> U {}

let diffFuncWithNondiff: @differentiable (Float, @nondiff Int) -> Float
func inferredConformances<T, U, V>(_: @differentiable (T, @nondiff U) -> V) {}
inferredConformances(diffFuncWithNondiff)

struct Vector<T> {
var x, y: T
}
extension Vector: Differentiable where T: Differentiable {}

// expected-note @+2 {{where 'T' = 'Int'}}
// expected-note @+1 {{where 'U' = 'Int'}}
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}

let nondiffVectorFunc: (Vector<Int>) -> Vector<Int>
// expected-error @+1 2 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
inferredConformancesGeneric(nondiffVectorFunc)

let diffVectorFunc: (Vector<Float>) -> Vector<Float>
inferredConformancesGeneric(diffVectorFunc) // okay!

func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}