Skip to content

Commit 35edb18

Browse files
committed
[AutoDiff upstream] Add the _Differentiation module.
The `_Differentiation` module is the experimental support library for differentiable programming. It is built when the build-script flag `--enable-experimental-differentiable-programming` is enabled. The `Differentiable` protocol generalizes all types that work with differentiation. It is a core piece of the differentiable programming project. Other parts depending on the `Differentiable` protocol will be upstreamed piece by piece. The `Differentiable` protocol is compiler-known and will be used during type-checking, SILGen, and the SIL differentiation transform.
1 parent 2bd55f6 commit 35edb18

File tree

9 files changed

+108
-8
lines changed

9 files changed

+108
-8
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ IDENTIFIER(decode)
5353
IDENTIFIER(decodeIfPresent)
5454
IDENTIFIER(Decoder)
5555
IDENTIFIER(decoder)
56+
IDENTIFIER_WITH_NAME(Differentiation, "_Differentiation")
5657
IDENTIFIER(dynamicallyCall)
5758
IDENTIFIER(dynamicMember)
5859
IDENTIFIER(Element)

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ PROTOCOL_(DestructorSafeContainer)
8383

8484
PROTOCOL(StringInterpolationProtocol)
8585

86+
PROTOCOL(Differentiable)
87+
8688
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
8789
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
8890
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "Dictionary", false)

lib/AST/ASTContext.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,9 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
899899
case KnownProtocolKind::CFObject:
900900
M = getLoadedModule(Id_CoreFoundation);
901901
break;
902+
case KnownProtocolKind::Differentiable:
903+
M = getLoadedModule(Id_Differentiation);
904+
break;
902905
default:
903906
M = getStdlibModule();
904907
break;

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4217,6 +4217,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42174217
case KnownProtocolKind::Encodable:
42184218
case KnownProtocolKind::Decodable:
42194219
case KnownProtocolKind::StringInterpolationProtocol:
4220+
case KnownProtocolKind::Differentiable:
42204221
return SpecialProtocol::None;
42214222
}
42224223

stdlib/public/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ if(SWIFT_BUILD_STDLIB)
6363
add_subdirectory(SwiftOnoneSupport)
6464
endif()
6565

66+
# Build differentiable programming support library only if enabled.
67+
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING AND SWIFT_BUILD_STDLIB)
68+
message(STATUS "Building Swift differentiable programming support library.")
69+
add_subdirectory(Differentiation)
70+
endif()
71+
6672
if(SWIFT_BUILD_STDLIB OR SWIFT_BUILD_REMOTE_MIRROR)
6773
add_subdirectory(Reflection)
6874
add_subdirectory(SwiftRemoteMirror)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#===--- CMakeLists.txt - Differentiable programming support library ------===#
2+
#
3+
# This source file is part of the Swift.org open source project
4+
#
5+
# Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
# Licensed under Apache License v2.0 with Runtime Library Exception
7+
#
8+
# See https://swift.org/LICENSE.txt for license information
9+
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
#
11+
#===----------------------------------------------------------------------===#
12+
13+
add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
14+
Differentiable.swift
15+
16+
SWIFT_COMPILE_FLAGS ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
17+
LINK_FLAGS "${SWIFT_RUNTIME_SWIFT_LINK_FLAGS}"
18+
INSTALL_IN_COMPONENT stdlib)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===--- Differentiable.swift ---------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines the Differentiable protocol, used by the experimental
14+
// differentiable programming project. This API is not stable and subject to
15+
// change.
16+
//
17+
// Please see forum discussion for more information about the differentiable
18+
// programming project:
19+
// https://forums.swift.org/t/differentiable-programming-mega-proposal/28547
20+
//
21+
//===----------------------------------------------------------------------===//
22+
23+
/// A type that mathematically represents a differentiable manifold whose
24+
/// tangent spaces are finite-dimensional.
25+
public protocol Differentiable {
26+
/// A type representing a differentiable value's derivatives.
27+
///
28+
/// Mathematically, this is equivalent to the tangent bundle of the
29+
/// differentiable manifold represented by the differentiable type.
30+
associatedtype TangentVector: Differentiable & AdditiveArithmetic
31+
where TangentVector.TangentVector == TangentVector
32+
33+
/// Moves `self` along the given direction. In Riemannian geometry, this is
34+
/// equivalent to exponential map, which moves `self` on the geodesic surface
35+
/// along the given tangent vector.
36+
mutating func move(along direction: TangentVector)
37+
}
38+
39+
public extension Differentiable where TangentVector == Self {
40+
@_alwaysEmitIntoClient
41+
mutating func move(along direction: TangentVector) {
42+
self += direction
43+
}
44+
}

stdlib/public/core/CMakeLists.txt

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,6 @@ set(SWIFTLIB_ESSENTIAL_GYB_SOURCES
196196
UnsafeRawBufferPointer.swift.gyb
197197
)
198198

199-
# Compile differentiable programming sources only if enabled.
200-
set(SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES)
201-
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
202-
# TODO: Add `_Differentiable` protocol.
203-
message(STATUS "Differentiable programming standard library additions enabled.")
204-
endif()
205-
206199
# The complete list of sources in the core standard library. Includes
207200
# all the essential sources listed above.
208201
set(SWIFTLIB_SOURCES
@@ -221,7 +214,6 @@ set(SWIFTLIB_SOURCES
221214
VarArgs.swift
222215
Zip.swift
223216
"${SWIFT_SOURCE_DIR}/stdlib/linker-support/magic-symbols-for-install-name.c"
224-
${SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES}
225217
)
226218

227219
set(SWIFTLIB_GYB_SOURCES
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-typecheck-verify-swift
2+
3+
import _Differentiation
4+
5+
// Test conformances.
6+
7+
struct Wrapper<T> {
8+
var value: T
9+
}
10+
extension Wrapper: Equatable where T: Equatable {}
11+
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
12+
static var zero: Self {
13+
Wrapper(value: T.zero)
14+
}
15+
static func + (lhs: Self, rhs: Self) -> Self {
16+
return Wrapper(value: lhs.value + rhs.value)
17+
}
18+
static func - (lhs: Self, rhs: Self) -> Self {
19+
return Wrapper(value: lhs.value + rhs.value)
20+
}
21+
}
22+
extension Wrapper: Differentiable where T: Differentiable {
23+
typealias TangentVector = Wrapper<T.TangentVector>
24+
mutating func move(along direction: TangentVector) {
25+
value.move(along: direction.value)
26+
}
27+
}
28+
29+
// Test conformances for standard library types.
30+
31+
extension Float: Differentiable {
32+
public typealias TangentVector = Self
33+
}

0 commit comments

Comments
 (0)