Skip to content

[AutoDiff] Move stdlib sources to stdlib/public/core/Differentiation. #28089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ IDENTIFIER(decode)
IDENTIFIER(decodeIfPresent)
IDENTIFIER(Decoder)
IDENTIFIER(decoder)
IDENTIFIER_(Differentiation)
IDENTIFIER(dynamicallyCall)
IDENTIFIER(dynamicMember)
IDENTIFIER(Element)
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ PROTOCOL_(DestructorSafeContainer)

PROTOCOL(StringInterpolationProtocol)

PROTOCOL_(Differentiable)
PROTOCOL(Differentiable)

EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
Expand Down
6 changes: 6 additions & 0 deletions stdlib/public/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ if(SWIFT_BUILD_STDLIB)
add_subdirectory(SwiftOnoneSupport)
endif()

# Build differentiable programming support library only if enabled.
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING AND SWIFT_BUILD_STDLIB)
message(STATUS "Building Swift differentiable programming support library.")
add_subdirectory(Differentiation)
endif()

# SWIFT_ENABLE_TENSORFLOW
if(SWIFT_BUILD_STDLIB)
add_subdirectory(Python)
Expand Down
28 changes: 28 additions & 0 deletions stdlib/public/Differentiation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#===--- CMakeLists.txt - Differentiable programming support library ------===#
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
#
#===----------------------------------------------------------------------===#

# SWIFT_ENABLE_TENSORFLOW
# NOTE: A non-empty `_Differentiation` module is currently created only on
# master branch, not on tensorflow branch.
#
# Instead, on tensorflow branch, the differentiation-related Swift source files
# in this directory are directly built as part of swiftCore: see
# stdlib/public/core/CMakeLists.txt. The `_Differentiation` module is created
# empty to avoid `#if canImport(_Differentiation)` guards in tests.
# SWIFT_ENABLE_TENSORFLOW END

add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
Empty.swift

SWIFT_COMPILE_FLAGS ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
LINK_FLAGS "${SWIFT_RUNTIME_SWIFT_LINK_FLAGS}"
INSTALL_IN_COMPONENT stdlib)
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,23 @@
//===----------------------------------------------------------------------===//
//
// This file defines the Differentiable protocol, used by the experimental
// differentiable programming project. Please see forum discussion for more
// information:
// differentiable programming project. This API is not stable and subject to
// change.
//
// Please see forum discussion for more information about the differentiable
// programming project:
// https://forums.swift.org/t/differentiable-programming-mega-proposal/28547
//
//===----------------------------------------------------------------------===//

/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol _Differentiable {
public protocol Differentiable {
/// A type representing a differentiable value's derivatives.
///
/// Mathematically, this is equivalent to the tangent bundle of the
/// differentiable manifold represented by the differentiable type.
associatedtype TangentVector: _Differentiable & AdditiveArithmetic
associatedtype TangentVector: Differentiable & AdditiveArithmetic
where TangentVector.TangentVector == TangentVector

/// Moves `self` along the given direction. In Riemannian geometry, this is
Expand All @@ -44,15 +47,18 @@ public protocol _Differentiable {
this property
""")
var zeroTangentVector: TangentVector { get }
// SWIFT_ENABLE_TENSORFLOW END
}

public extension _Differentiable where TangentVector == Self {
public extension Differentiable where TangentVector == Self {
@_alwaysEmitIntoClient
mutating func move(along direction: TangentVector) {
self += direction
}
}

public extension _Differentiable {
// SWIFT_ENABLE_TENSORFLOW
public extension Differentiable {
// This is a temporary solution that allows us to add `zeroTangentVector`
// without implementing derived conformances. This property is marked
// unavailable because it will produce incorrect results when tangent vectors
Expand All @@ -61,6 +67,4 @@ public extension _Differentiable {
// implementation.
var zeroTangentVector: TangentVector { .zero }
}

// SWIFT_ENABLE_TENSORFLOW
public typealias Differentiable = _Differentiable
// SWIFT_ENABLE_TENSORFLOW END
3 changes: 3 additions & 0 deletions stdlib/public/Differentiation/Empty.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// SWIFT_ENABLE_TENSORFLOW
// Empty Swift file, only for tensorflow branch.
// See explanation in stdlib/public/Differentiation/CMakeLists.txt.
9 changes: 6 additions & 3 deletions stdlib/public/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,16 @@ set(SWIFTLIB_ESSENTIAL_GYB_SOURCES
UnsafeRawBufferPointer.swift.gyb
)

# SWIFT_ENABLE_TENSORFLOW
# Compile differentiable programming sources only if enabled.
set(SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES)
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
list(APPEND SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES
Differentiable.swift
# SWIFT_ENABLE_TENSORFLOW
DifferentiationSupport.swift)
../Differentiation/Differentiable.swift
../Differentiation/DifferentiationSupport.swift)
message(STATUS "Differentiable programming standard library additions enabled.")
endif()
# SWIFT_ENABLE_TENSORFLOW END

# The complete list of sources in the core standard library. Includes
# all the essential sources listed above.
Expand All @@ -225,7 +226,9 @@ set(SWIFTLIB_SOURCES
VarArgs.swift
Zip.swift
"${SWIFT_SOURCE_DIR}/stdlib/linker-support/magic-symbols-for-install-name.c"
# SWIFT_ENABLE_TENSORFLOW
${SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES}
# SWIFT_ENABLE_TENSORFLOW END
)

set(SWIFTLIB_GYB_SOURCES
Expand Down
9 changes: 7 additions & 2 deletions stdlib/public/core/SIMDVectorTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,17 @@ extension SIMD${n} where Scalar: BinaryFloatingPoint {
// SWIFT_ENABLE_TENSORFLOW
extension SIMD${n} : AdditiveArithmetic where Scalar : FloatingPoint {}

extension SIMD${n} : Differentiable & EuclideanDifferentiable
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
extension SIMD${n} : Differentiable
where Scalar : Differentiable & BinaryFloatingPoint,
Scalar.TangentVector : BinaryFloatingPoint {
public typealias TangentVector = SIMD${n}
}

extension SIMD${n} : EuclideanDifferentiable
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
Scalar.TangentVector : BinaryFloatingPoint {
}

extension SIMD${n}
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
Scalar.TangentVector : BinaryFloatingPoint {
Expand Down
14 changes: 7 additions & 7 deletions test/AutoDiff/derived_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct TestNoDerivative : EuclideanDifferentiable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic, ElementaryFunctions
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
// CHECK-AST: internal var differentiableVectorView: TestNoDerivative.TangentVector { get }

Expand All @@ -57,7 +57,7 @@ struct TestPointwiseMultiplicative : Differentiable {
// CHECK-AST: var w: PointwiseMultiplicativeDummy
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic, PointwiseMultiplicative
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector


Expand All @@ -70,14 +70,14 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector

struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic {
var x: T.TangentVector
}

// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : _Differentiable
// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
// CHECK-AST: internal var x: T.TangentVector
// CHECK-AST: internal init(x: T.TangentVector)
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
Expand All @@ -92,7 +92,7 @@ public struct ConditionallyDifferentiable<T> {
extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}

// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
// CHECK-AST: @differentiable(wrt: self where T : _Differentiable)
// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
// CHECK-AST: public var x: T
// CHECK-AST: internal init(x: T)
// CHECK-AST: }
Expand Down Expand Up @@ -121,7 +121,7 @@ final class AdditiveArithmeticClass<T : AdditiveArithmetic & Differentiable> : A
}
}

// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : _Differentiable {
// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
// CHECK-AST: final internal var x: T, y: T
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
// CHECK-AST: }
8 changes: 4 additions & 4 deletions test/AutoDiff/differentiability_witness_function_inst.sil
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ bb0:
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : _Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : _Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: }

// CHECK-LABEL: sil @test_transpose_witnesses : $@convention(thin) () -> () {
Expand All @@ -86,7 +86,7 @@ bb0:
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0] [results 0] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0 1] @bar : $@convention(thin) (Float, Float, Float) -> (Float, Float)
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : _Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : _Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [transpose] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: return undef : $()
// CHECK: }
4 changes: 2 additions & 2 deletions test/AutoDiff/differentiable_attr_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public func foo_indir_ret<T: Differentiable>(_ x: Float, _ y: T) -> T {
return y
}

// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 vjp @AD__foo_indir_ret__vjp_src_0_wrt_0_1] [ossa] @foo_indir_ret : $@convention(thin) <T where T : _Differentiable> (Float, @in_guaranteed T) -> @out T {
// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 vjp @AD__foo_indir_ret__vjp_src_0_wrt_0_1] [ossa] @foo_indir_ret : $@convention(thin) <T where T : Differentiable> (Float, @in_guaranteed T) -> @out T {
// CHECK-SIL: bb0(%0 : $*T, %1 : $Float, %2 : $*T):

@_silgen_name("dfoo_indir_ret")
Expand Down Expand Up @@ -101,7 +101,7 @@ struct DiffComputedProp : Differentiable & AdditiveArithmetic {
// Check that `@differentiable` attribute is transferred from computed property
// storage declaration to getter accessor.

// CHECK-AST: struct DiffComputedProp : _Differentiable & AdditiveArithmetic {
// CHECK-AST: struct DiffComputedProp : AdditiveArithmetic & Differentiable {
// CHECK-AST-NEXT: var computedProp: Float { get }
// CHECK-AST: }

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ struct TF_521<T: FloatingPoint> {
self.imaginary = imaginary
}
}
// expected-error @+2 {{type 'TF_521<T>' does not conform to protocol '_Differentiable'}}
// expected-error @+2 {{type 'TF_521<T>' does not conform to protocol 'Differentiable'}}
// expected-note @+1 {{do you want to add protocol stubs}}
extension TF_521: Differentiable where T: Differentiable {
// expected-note @+1 {{possibly intended match 'TF_521<T>.TangentVector' does not conform to 'AdditiveArithmetic'}}
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_func_debuginfo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
// Conclusion: mangling coverage is important.

// Minimal dummy compiler-known `Differentiable` protocol.
public protocol _Differentiable {
public protocol Differentiable {
associatedtype TangentVector
}

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_func_type_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ extension Vector: Differentiable where T: Differentiable {}
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}

func nondiffVectorFunc(x: Vector<Int>) -> Vector<Int> {}
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to '_Differentiable}}
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
inferredConformancesGeneric(nondiffVectorFunc)

func diffVectorFunc(x: Vector<Float>) -> Vector<Float> {}
Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/differentiable_function_inst_lowered.sil
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ sil_stage lowered
import Swift
import Builtin

struct Large : _Differentiable {
struct Large : Differentiable {
@_hasStorage @noDerivative let a: Float { get }
@_hasStorage @noDerivative let b: Float { get }
@_hasStorage @noDerivative let c: Float { get }
@_hasStorage @noDerivative let d: Float { get }
@_hasStorage @noDerivative let e: Float { get }
init(a: Float, b: Float, c: Float, d: Float, e: Float)
struct TangentVector : _Differentiable, AdditiveArithmetic {
struct TangentVector : Differentiable, AdditiveArithmetic {
init()
typealias TangentVector = Large.TangentVector
static var zero: Large.TangentVector { get }
Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/differentiable_function_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) {
// CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
// CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [parameters 0] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with_derivative {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
46 changes: 46 additions & 0 deletions test/AutoDiff/differentiable_protocol.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: %target-typecheck-verify-swift
// REQUIRES: differentiable_programming

import _Differentiation

// Test conformances.

struct FloatWrapper {
var value: Float
}
extension FloatWrapper: AdditiveArithmetic {
static var zero: Self {
FloatWrapper(value: Float.zero)
}
static func + (lhs: Self, rhs: Self) -> Self {
return FloatWrapper(value: lhs.value + rhs.value)
}
static func - (lhs: Self, rhs: Self) -> Self {
return FloatWrapper(value: lhs.value + rhs.value)
}
}
extension FloatWrapper: Differentiable {
public typealias TangentVector = Self
}

struct Wrapper<T> {
var value: T
}
extension Wrapper: Equatable where T: Equatable {}
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
static var zero: Self {
Wrapper(value: T.zero)
}
static func + (lhs: Self, rhs: Self) -> Self {
return Wrapper(value: lhs.value + rhs.value)
}
static func - (lhs: Self, rhs: Self) -> Self {
return Wrapper(value: lhs.value + rhs.value)
}
}
extension Wrapper: Differentiable where T: Differentiable {
typealias TangentVector = Wrapper<T.TangentVector>
mutating func move(along direction: TangentVector) {
value.move(along: direction.value)
}
}
Loading