Skip to content

Commit f2b6365

Browse files
committed
[AutoDiff upstream] Add the _Differentiation module.
The `_Differentiation` module is the experimental support library for differentiable programming. It is built when the build-script flag `--enable-experimental-differentiable-programming` is enabled. The `Differentiable` protocol generalizes all types that work with differentiation. It is a core piece of the differentiable programming project. Other parts depending on the `Differentiable` protocol will be upstreamed piece by piece. The `Differentiable` protocol is compiler-known and will be used during type-checking, SILGen, and the SIL differentiation transform.
1 parent 8afd60c commit f2b6365

File tree

12 files changed

+133
-8
lines changed

12 files changed

+133
-8
lines changed

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ IDENTIFIER(decode)
5454
IDENTIFIER(decodeIfPresent)
5555
IDENTIFIER(Decoder)
5656
IDENTIFIER(decoder)
57+
IDENTIFIER_(Differentiation)
5758
IDENTIFIER(dynamicallyCall)
5859
IDENTIFIER(dynamicMember)
5960
IDENTIFIER(Element)

include/swift/AST/KnownProtocols.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ PROTOCOL_(DestructorSafeContainer)
8383

8484
PROTOCOL(StringInterpolationProtocol)
8585

86+
PROTOCOL(Differentiable)
87+
8688
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
8789
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
8890
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "Dictionary", false)

lib/AST/ASTContext.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,9 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
898898
case KnownProtocolKind::CFObject:
899899
M = getLoadedModule(Id_CoreFoundation);
900900
break;
901+
case KnownProtocolKind::Differentiable:
902+
M = getLoadedModule(Id_Differentiation);
903+
break;
901904
default:
902905
M = getStdlibModule();
903906
break;

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4217,6 +4217,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42174217
case KnownProtocolKind::Encodable:
42184218
case KnownProtocolKind::Decodable:
42194219
case KnownProtocolKind::StringInterpolationProtocol:
4220+
case KnownProtocolKind::Differentiable:
42204221
return SpecialProtocol::None;
42214222
}
42224223

stdlib/public/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ if(SWIFT_BUILD_STDLIB)
6363
add_subdirectory(SwiftOnoneSupport)
6464
endif()
6565

66+
# Build differentiable programming support library only if enabled.
67+
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING AND SWIFT_BUILD_STDLIB)
68+
message(STATUS "Building Swift differentiable programming support library.")
69+
add_subdirectory(Differentiation)
70+
endif()
71+
6672
if(SWIFT_BUILD_STDLIB OR SWIFT_BUILD_REMOTE_MIRROR)
6773
add_subdirectory(Reflection)
6874
add_subdirectory(SwiftRemoteMirror)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#===--- CMakeLists.txt - Differentiable programming support library ------===#
2+
#
3+
# This source file is part of the Swift.org open source project
4+
#
5+
# Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
# Licensed under Apache License v2.0 with Runtime Library Exception
7+
#
8+
# See https://swift.org/LICENSE.txt for license information
9+
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
#
11+
#===----------------------------------------------------------------------===#
12+
13+
add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
14+
Differentiable.swift
15+
16+
SWIFT_COMPILE_FLAGS ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
17+
LINK_FLAGS "${SWIFT_RUNTIME_SWIFT_LINK_FLAGS}"
18+
INSTALL_IN_COMPONENT stdlib)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===--- Differentiable.swift ---------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines the Differentiable protocol, used by the experimental
14+
// differentiable programming project. This API is not stable and subject to
15+
// change.
16+
//
17+
// Please see forum discussion for more information about the differentiable
18+
// programming project:
19+
// https://forums.swift.org/t/differentiable-programming-mega-proposal/28547
20+
//
21+
//===----------------------------------------------------------------------===//
22+
23+
/// A type that mathematically represents a differentiable manifold whose
24+
/// tangent spaces are finite-dimensional.
25+
public protocol Differentiable {
26+
/// A type representing a differentiable value's derivatives.
27+
///
28+
/// Mathematically, this is equivalent to the tangent bundle of the
29+
/// differentiable manifold represented by the differentiable type.
30+
associatedtype TangentVector: Differentiable & AdditiveArithmetic
31+
where TangentVector.TangentVector == TangentVector
32+
33+
/// Moves `self` along the given direction. In Riemannian geometry, this is
34+
/// equivalent to exponential map, which moves `self` on the geodesic surface
35+
/// along the given tangent vector.
36+
mutating func move(along direction: TangentVector)
37+
}
38+
39+
public extension Differentiable where TangentVector == Self {
40+
@_alwaysEmitIntoClient
41+
mutating func move(along direction: TangentVector) {
42+
self += direction
43+
}
44+
}

stdlib/public/core/CMakeLists.txt

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,6 @@ set(SWIFTLIB_ESSENTIAL_GYB_SOURCES
196196
UnsafeRawBufferPointer.swift.gyb
197197
)
198198

199-
# Compile differentiable programming sources only if enabled.
200-
set(SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES)
201-
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
202-
# TODO: Add `_Differentiable` protocol.
203-
message(STATUS "Differentiable programming standard library additions enabled.")
204-
endif()
205-
206199
# The complete list of sources in the core standard library. Includes
207200
# all the essential sources listed above.
208201
set(SWIFTLIB_SOURCES
@@ -221,7 +214,6 @@ set(SWIFTLIB_SOURCES
221214
VarArgs.swift
222215
Zip.swift
223216
"${SWIFT_SOURCE_DIR}/stdlib/linker-support/magic-symbols-for-install-name.c"
224-
${SWIFTLIB_DIFFERENTIABLE_PROGRAMMING_SOURCES}
225217
)
226218

227219
set(SWIFTLIB_GYB_SOURCES
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: %target-typecheck-verify-swift
2+
// REQUIRES: differentiable_programming
3+
4+
import _Differentiation
5+
6+
// Test conformances.
7+
8+
struct FloatWrapper {
9+
var value: Float
10+
}
11+
extension FloatWrapper: AdditiveArithmetic {
12+
static var zero: Self {
13+
FloatWrapper(value: Float.zero)
14+
}
15+
static func + (lhs: Self, rhs: Self) -> Self {
16+
return FloatWrapper(value: lhs.value + rhs.value)
17+
}
18+
static func - (lhs: Self, rhs: Self) -> Self {
19+
return FloatWrapper(value: lhs.value + rhs.value)
20+
}
21+
}
22+
extension FloatWrapper: Differentiable {
23+
public typealias TangentVector = Self
24+
}
25+
26+
struct Wrapper<T> {
27+
var value: T
28+
}
29+
extension Wrapper: Equatable where T: Equatable {}
30+
extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic {
31+
static var zero: Self {
32+
Wrapper(value: T.zero)
33+
}
34+
static func + (lhs: Self, rhs: Self) -> Self {
35+
return Wrapper(value: lhs.value + rhs.value)
36+
}
37+
static func - (lhs: Self, rhs: Self) -> Self {
38+
return Wrapper(value: lhs.value + rhs.value)
39+
}
40+
}
41+
extension Wrapper: Differentiable where T: Differentiable {
42+
typealias TangentVector = Wrapper<T.TangentVector>
43+
mutating func move(along direction: TangentVector) {
44+
value.move(along: direction.value)
45+
}
46+
}

test/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ normalize_boolean_spelling(SWIFT_AST_VERIFIER)
134134
normalize_boolean_spelling(SWIFT_ASAN_BUILD)
135135
normalize_boolean_spelling(SWIFT_BUILD_SYNTAXPARSERLIB)
136136
normalize_boolean_spelling(SWIFT_ENABLE_SOURCEKIT_TESTS)
137+
normalize_boolean_spelling(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
137138
is_build_type_optimized("${SWIFT_STDLIB_BUILD_TYPE}" SWIFT_OPTIMIZED)
138139

139140
set(profdata_merge_worker
@@ -305,6 +306,10 @@ _Block_release(void) { }\n")
305306

306307
list(APPEND LIT_ARGS "--xunit-xml-output=${SWIFT_TEST_RESULTS_DIR}/lit-tests.xml")
307308

309+
if(SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING)
310+
list(APPEND LIT_ARGS "--param" "differentiable_programming")
311+
endif()
312+
308313
foreach(test_subset ${TEST_SUBSETS})
309314
set(directories)
310315
set(dependencies ${test_dependencies})

test/lit.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,10 @@ swift_version = lit_config.params.get('swift-version',
345345
lit_config.note('Compiling with -swift-version ' + swift_version)
346346
config.swift_test_options = '-swift-version ' + swift_version
347347

348+
differentiable_programming = lit_config.params.get('differentiable_programming', None)
349+
if differentiable_programming is not None:
350+
config.available_features.add('differentiable_programming')
351+
348352
test_options = os.environ.get('SWIFT_TEST_OPTIONS')
349353
if test_options:
350354
config.swift_test_options += ' '

test/lit.site.cfg.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ else:
121121
if '@SWIFT_INCLUDE_TOOLS@' == 'TRUE':
122122
config.available_features.add('swift_tools_extra')
123123

124+
if "@SWIFT_ENABLE_EXPERIMENTAL_DIFFERENTIABLE_PROGRAMMING@" == "TRUE":
125+
config.available_features.add('differentiable_programming')
126+
124127
# Let the main config do the real work.
125128
if config.test_exec_root is None:
126129
config.test_exec_root = os.path.dirname(os.path.realpath(__file__))

0 commit comments

Comments
 (0)