Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1b4c33a

Browse files
committed
Merge branch 'master' into complex-numbers2
2 parents b5cf071 + d7eff12 commit 1b4c33a

19 files changed

+1053
-316
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ xcuserdata
55
DerivedData/
66
*.xcodeproj
77
*~
8+
*.vscode
9+
*.idea
810

911
### MacOS ###
1012
.DS_Store

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04
44

55
# Allows the caller to specify the toolchain to use.
6-
ARG swift_tf_url=https://storage.googleapis.com/s4tf-kokoro-artifact-testing/latest/swift-tensorflow-DEVELOPMENT-cuda10.0-cudnn7-ubuntu18.04.tar.gz
6+
ARG swift_tf_url=https://storage.googleapis.com/s4tf-kokoro-artifact-testing/latest/swift-tensorflow-DEVELOPMENT-cuda10.0-cudnn7-test-ubuntu18.04.tar.gz
77

88
# Install Swift deps.
99
ENV DEBIAN_FRONTEND=noninteractive

Sources/TensorFlow/Core/DataTypes.swift

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@
1414

1515
import CTensorFlow
1616

17-
public extension TensorDataType {
18-
var _cDataType: TF_DataType {
19-
return TF_DataType(rawValue: _internalStorageType)
17+
/// A TensorFlow dynamic type value that can be created from types that conform to
18+
/// `TensorFlowScalar`.
19+
// This simply wraps a `TF_DataType` and allows user code to handle
20+
// `TF_DataType` without importing CTensorFlow, which pollutes the namespace
21+
// with TensorFlow C API declarations.
22+
public struct TensorDataType {
23+
public var _cDataType: TF_DataType
24+
25+
@usableFromInline
26+
internal init(_ cDataType: TF_DataType) {
27+
self._cDataType = cDataType
2028
}
21-
22-
init(_ cDataType: TF_DataType) {
23-
self.init(rawValue: cDataType.rawValue)
24-
}
2529
}
2630

2731
@usableFromInline

Sources/TensorFlow/Core/DifferentialOperators.swift

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ public extension Differentiable {
2121
func gradient<R: TensorFlowFloatingPoint>(
2222
in f: @differentiable (Self) -> Tensor<R>
2323
) -> TangentVector {
24-
return self.pullback(in: f)(Tensor<R>(1))
24+
return self.valueWithGradient(in: f).1
2525
}
2626

2727
@inlinable
2828
func valueWithGradient<R: TensorFlowFloatingPoint>(
2929
in f: @differentiable (Self) -> Tensor<R>
3030
) -> (value: Tensor<R>, gradient: TangentVector) {
3131
let (y, pb) = self.valueWithPullback(in: f)
32+
precondition(y.rank == 0)
3233
return (y, pb(Tensor<R>(1)))
3334
}
3435

@@ -37,7 +38,7 @@ public extension Differentiable {
3738
at x: T,
3839
in f: @differentiable (Self, T) -> Tensor<R>
3940
) -> (TangentVector, T.TangentVector) {
40-
return self.pullback(at: x, in: f)(Tensor<R>(1))
41+
return self.valueWithGradient(at: x, in: f).1
4142
}
4243

4344
@inlinable
@@ -46,6 +47,7 @@ public extension Differentiable {
4647
in f: @differentiable (Self, T) -> Tensor<R>
4748
) -> (value: Tensor<R>, gradient: (TangentVector, T.TangentVector)) {
4849
let (y, pb) = self.valueWithPullback(at: x, in: f)
50+
precondition(y.rank == 0)
4951
return (y, pb(Tensor<R>(1)))
5052
}
5153
}
@@ -63,6 +65,7 @@ public func valueWithGradient<T, R>(
6365
) -> (value: Tensor<R>, gradient: T.TangentVector)
6466
where T: Differentiable, R: TensorFlowFloatingPoint {
6567
let (y, pullback) = valueWithPullback(at: x, in: f)
68+
precondition(y.rank == 0)
6669
return (y, pullback(Tensor<R>(1)))
6770
}
6871

@@ -74,6 +77,7 @@ public func valueWithGradient<T, U, R>(
7477
) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
7578
where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
7679
let (y, pullback) = valueWithPullback(at: x, y, in: f)
80+
precondition(y.rank == 0)
7781
return (y, pullback(Tensor<R>(1)))
7882
}
7983

@@ -86,6 +90,7 @@ public func valueWithGradient<T, U, R>(
8690
// ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
8791
// where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
8892
// let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
93+
// precondition(y.rank == 0)
8994
// return (y, pullback(Tensor<R>(1)))
9095
// }
9196

@@ -124,7 +129,7 @@ public func gradient<T, R>(
124129
at x: T,
125130
in f: @differentiable (T) -> Tensor<R>
126131
) -> T.TangentVector where T: Differentiable, R: TensorFlowFloatingPoint {
127-
return pullback(at: x, in: f)(Tensor<R>(1))
132+
return valueWithGradient(at: x, in: f).1
128133
}
129134

130135
@inlinable
@@ -134,7 +139,7 @@ public func gradient<T, U, R>(
134139
in f: @differentiable (T, U) -> Tensor<R>
135140
) -> (T.TangentVector, U.TangentVector)
136141
where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
137-
return pullback(at: x, y, in: f)(Tensor<R>(1))
142+
return valueWithGradient(at: x, y, in: f).1
138143
}
139144

140145
// @inlinable
@@ -145,7 +150,7 @@ public func gradient<T, U, R>(
145150
// in f: @differentiable (T, U, V) -> Tensor<R>
146151
// ) -> (T.TangentVector, U.TangentVector, V.TangentVector)
147152
// where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
148-
// return pullback(at: x, y, z, in: f)(Tensor<R>(1))
153+
// return valueWithGradient(at: x, y, z, in: f).1
149154
// }
150155

151156
// Gradient (curried)

Sources/TensorFlow/Core/Runtime.swift

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ public final class _ExecutionContext {
551551
@usableFromInline let eagerContext: CTFEContext
552552

553553
/// The status for checking TensorFlow errors.
554-
private let status: CTFStatus = TF_NewStatus()
554+
@usableFromInline let status: CTFStatus = TF_NewStatus()
555555

556556
/// The mutex for preventing potential concurrent access.
557557
private var mutex: pthread_mutex_t = pthread_mutex_t()
@@ -569,6 +569,12 @@ public final class _ExecutionContext {
569569
// Initialize the TF runtime exactly once. Only affects local execution
570570
// (when _RuntimeConfig.tensorFlowServer is set to "").
571571
if !_RuntimeConfig.tensorFlowRuntimeInitialized {
572+
// Install a signal handler to ensure we exit when interrupted.
573+
signal(SIGINT) { _ in
574+
print("Caught interrupt signal, exiting...")
575+
exit(1)
576+
}
577+
572578
var args = ["dummyProgramName"]
573579
if _RuntimeConfig.printsDebugLog {
574580
args.append("--alsologtostderr")
@@ -588,24 +594,19 @@ public final class _ExecutionContext {
588594

589595
// Calculate the addresses of all the strings within our single buffer, and then call
590596
// TF_InitMain.
591-
flattenedStringBytes.withUnsafeMutableBufferPointer { flattenedStringBytesBuffer in
597+
flattenedStringBytes.withUnsafeMutableBufferPointer { buffer in
592598
var stringAddrs: [UnsafeMutablePointer<Int8>?] = []
593-
var currentStringAddr = flattenedStringBytesBuffer.baseAddress
599+
var currentStringAddr = buffer.baseAddress
594600
.map(UnsafeMutablePointer.init)
595601
for length in lengths {
596602
stringAddrs.append(currentStringAddr)
597603
currentStringAddr = currentStringAddr?.advanced(by: length)
598604
}
599605

600606
stringAddrs.withUnsafeMutableBufferPointer { stringAddrsBuffer in
601-
var cArgs = [stringAddrsBuffer.baseAddress.map(UnsafeMutablePointer.init)]
602-
var cArgsCount = [Int32(args.count)]
603-
604-
cArgs.withUnsafeMutableBufferPointer { cArgsBuffer in
605-
cArgsCount.withUnsafeMutableBufferPointer { cArgsCountBuffer in
606-
TF_InitMain(nil, cArgsCountBuffer.baseAddress, cArgsBuffer.baseAddress)
607-
}
608-
}
607+
var cArgsCount = Int32(args.count)
608+
var cArgs = stringAddrsBuffer.baseAddress.map(UnsafeMutablePointer.init)
609+
TF_InitMain(nil, &cArgsCount, &cArgs)
609610
}
610611
}
611612
_RuntimeConfig.tensorFlowRuntimeInitialized = true

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
import CTensorFlow
16+
1517
infix operator .==: ComparisonPrecedence
1618
infix operator .!=: ComparisonPrecedence
1719

@@ -50,7 +52,10 @@ public extension Tensor {
5052
var rank: Int {
5153
@_semantics("autodiff.nonvarying")
5254
get {
53-
return Int(rankTensor.scalar!)
55+
let status = _ExecutionContext.global.status
56+
let rank = TFE_TensorHandleNumDims(handle._cTensorHandle, status)
57+
checkOk(status)
58+
return Int(rank)
5459
}
5560
}
5661

@@ -59,15 +64,25 @@ public extension Tensor {
5964
var shape: TensorShape {
6065
@_semantics("autodiff.nonvarying")
6166
get {
62-
return TensorShape(shapeTensor.scalars.map(Int.init))
67+
let status = _ExecutionContext.global.status
68+
let dims: [Int] = (0..<Int32(rank)).map { i in
69+
let dim = TFE_TensorHandleDim(self.handle._cTensorHandle, i, status)
70+
checkOk(status)
71+
return Int(dim)
72+
}
73+
return TensorShape(dims)
6374
}
6475
}
6576

6677
/// The number of scalars in the `Tensor`.
6778
@inlinable
6879
var scalarCount: Int {
80+
@_semantics("autodiff.nonvarying")
6981
get {
70-
return Int(scalarCountTensor.scalar!)
82+
let status = _ExecutionContext.global.status
83+
let size = TFE_TensorHandleNumElements(handle._cTensorHandle, status)
84+
checkOk(status)
85+
return Int(size)
7186
}
7287
}
7388

@@ -511,20 +526,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
511526
lhs: Tensor,
512527
rhs: Tensor
513528
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
514-
return (lhs + rhs, { [
515-
lhsShape = lhs.shape,
516-
rhsShape = rhs.shape,
517-
lhsShapeTensor = lhs.shapeTensor,
518-
rhsShapeTensor = rhs.shapeTensor] v in
519-
var lhsGrad = v
520-
var rhsGrad = v
521-
if lhsGrad.shape != lhsShape {
522-
lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
523-
}
524-
if rhsGrad.shape != rhsShape {
525-
rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
526-
}
527-
return (lhsGrad, rhsGrad)
529+
return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
530+
let lhsGrad = v
531+
let rhsGrad = lhsGrad
532+
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
533+
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
534+
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
528535
})
529536
}
530537

@@ -533,20 +540,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
533540
lhs: Tensor,
534541
rhs: Tensor
535542
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
536-
return (lhs - rhs, { [
537-
lhsShape = lhs.shape,
538-
rhsShape = rhs.shape,
539-
lhsShapeTensor = lhs.shapeTensor,
540-
rhsShapeTensor = rhs.shapeTensor] v in
541-
var lhsGrad = v
542-
var rhsGrad = -v
543-
if lhsGrad.shape != lhsShape {
544-
lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
545-
}
546-
if rhsGrad.shape != rhsShape {
547-
rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
548-
}
549-
return (lhsGrad, rhsGrad)
543+
return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
544+
let lhsGrad = v
545+
let rhsGrad = -lhsGrad
546+
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
547+
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
548+
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
550549
})
551550
}
552551
}

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,45 @@
1414

1515
import CTensorFlow
1616

17+
/// A protocol representing types that can be mapped to `Array<CTensorHandle>`.
18+
///
19+
/// This protocol is defined separately from `TensorGroup` in order for the number of tensors to be
20+
/// determined at runtime. For example, `[Tensor<Float>]` may have an unknown number of elements at
21+
/// compile time.
22+
///
23+
/// This protocol can be derived automatically for structs whose stored properties all conform to
24+
/// the `TensorGroup` protocol. It cannot be derived automatically for structs whose properties all
25+
/// conform to `TensorArrayProtocol` due to the constructor requirement (i.e., in such cases it
26+
/// would be impossible to know how to break down `count` among the stored properties).
27+
public protocol TensorArrayProtocol {
28+
/// Writes the tensor handles to `address`, which must be allocated with enough capacity to hold
29+
/// `_tensorHandleCount` handles. The tensor handles written to `address` are borrowed: this
30+
/// container still owns them.
31+
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?)
32+
33+
var _tensorHandleCount: Int32 { get }
34+
var _typeList: [TensorDataType] { get }
35+
36+
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
37+
}
38+
39+
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
40+
///
41+
/// When a `TensorGroup` is used as an argument to a tensor operation, it is passed as an argument
42+
/// list whose elements are the tensor fields of the type.
43+
///
44+
/// When a `TensorGroup` is returned as a result of a tensor operation, it is initialized with its
45+
/// tensor fields set to the tensor operation's tensor results.
46+
public protocol TensorGroup: TensorArrayProtocol {
47+
48+
/// The types of the tensor stored properties in this type.
49+
static var _typeList: [TensorDataType] { get }
50+
51+
/// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors
52+
/// starting at address `tensorHandles`.
53+
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?)
54+
}
55+
1756
public extension TensorGroup {
1857
/// The number of tensor fields in this type.
1958
static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }

0 commit comments

Comments
 (0)