Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 62b15ee

Browse files
committed
Merge branch 'master' of https://github.com/tensorflow/swift-apis into conv
2 parents e787685 + e909225 commit 62b15ee

17 files changed

+329
-156
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct Model: Layer {
2929
var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3, activation: identity)
3030

3131
@differentiable
32-
func call(_ input: Tensor<Float>) -> Tensor<Float> {
32+
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
3333
return input.sequenced(through: layer1, layer2, layer3)
3434
}
3535
}

Sources/TensorFlow/Core/TensorGroup.swift

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol {
3232

3333
var _tensorHandleCount: Int32 { get }
3434
var _typeList: [TensorDataType] { get }
35+
var _tensorHandles: [_AnyTensorHandle] { get }
3536

3637
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
38+
init<C: RandomAccessCollection>(_handles: C) where C.Element: _AnyTensorHandle
3739
}
3840

3941
/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
@@ -88,13 +90,22 @@ extension TensorHandle: TensorGroup {
8890
return [Scalar.tensorFlowDataType]
8991
}
9092

93+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }
94+
9195
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
9296
address!.initialize(to: _cTensorHandle)
9397
}
9498

9599
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
96100
self.init(_owning: tensorHandles!.pointee)
97101
}
102+
103+
public init<C: RandomAccessCollection>(
104+
_handles: C
105+
) where C.Element: _AnyTensorHandle {
106+
precondition(_handles.count == 1)
107+
self.init(handle: _handles[_handles.startIndex])
108+
}
98109
}
99110

100111
extension ResourceHandle: TensorGroup {
@@ -108,13 +119,22 @@ extension ResourceHandle: TensorGroup {
108119
return [TensorDataType(TF_RESOURCE)]
109120
}
110121

122+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }
123+
111124
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
112125
address!.initialize(to: _cTensorHandle)
113126
}
114127

115128
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
116129
self.init(owning: tensorHandles!.pointee)
117130
}
131+
132+
public init<C: RandomAccessCollection>(
133+
_handles: C
134+
) where C.Element: _AnyTensorHandle {
135+
precondition(_handles.count == 1)
136+
self.init(handle: _handles[_handles.startIndex])
137+
}
118138
}
119139

120140
extension VariantHandle: TensorGroup {
@@ -128,13 +148,22 @@ extension VariantHandle: TensorGroup {
128148
return [TensorDataType(TF_VARIANT)]
129149
}
130150

151+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }
152+
131153
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
132154
address!.initialize(to: _cTensorHandle)
133155
}
134156

135157
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
136158
self.init(owning: tensorHandles!.pointee)
137159
}
160+
161+
public init<C: RandomAccessCollection>(
162+
_handles: C
163+
) where C.Element: _AnyTensorHandle {
164+
precondition(_handles.count == 1)
165+
self.init(handle: _handles[_handles.startIndex])
166+
}
138167
}
139168

140169
extension Tensor: TensorGroup {
@@ -152,9 +181,18 @@ extension Tensor: TensorGroup {
152181
address!.initialize(to: handle._cTensorHandle)
153182
}
154183

184+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
185+
155186
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
156187
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
157188
}
189+
190+
public init<C: RandomAccessCollection>(
191+
_handles: C
192+
) where C.Element: _AnyTensorHandle {
193+
precondition(_handles.count == 1)
194+
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
195+
}
158196
}
159197

160198
extension _TensorElementLiteral: TensorGroup {
@@ -168,13 +206,22 @@ extension _TensorElementLiteral: TensorGroup {
168206
return [Scalar.tensorFlowDataType]
169207
}
170208

209+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
210+
171211
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
172212
address!.initialize(to: handle._cTensorHandle)
173213
}
174214

175215
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
176216
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
177217
}
218+
219+
public init<C: RandomAccessCollection>(
220+
_handles: C
221+
) where C.Element: _AnyTensorHandle {
222+
precondition(_handles.count == 1)
223+
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
224+
}
178225
}
179226

180227
extension StringTensor: TensorGroup {
@@ -192,9 +239,18 @@ extension StringTensor: TensorGroup {
192239
address!.initialize(to: handle._cTensorHandle)
193240
}
194241

242+
public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }
243+
195244
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
196245
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
197246
}
247+
248+
public init<C: RandomAccessCollection>(
249+
_handles: C
250+
) where C.Element: _AnyTensorHandle {
251+
precondition(_handles.count == 1)
252+
self.init(handle: TensorHandle(handle: _handles[_handles.startIndex]))
253+
}
198254
}
199255

200256
extension Array: TensorArrayProtocol where Element: TensorGroup {
@@ -216,10 +272,32 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
216272
count: Int(count)).joined())
217273
}
218274

275+
public var _tensorHandles: ([_AnyTensorHandle]) {
276+
var result: [_AnyTensorHandle] = []
277+
result.reserveCapacity(Int(self._tensorHandleCount))
278+
for elem in self {
279+
result += elem._tensorHandles
280+
}
281+
return result
282+
}
283+
219284
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
220285
let size = count / Int(Element._tensorHandleCount)
221286
self = Array((0..<size).map { Element.init(
222287
_owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount)))
223288
})
224289
}
290+
291+
public init<C: RandomAccessCollection>(
292+
_handles: C
293+
) where C.Element: _AnyTensorHandle {
294+
let size = _handles.count / Int(Element._tensorHandleCount)
295+
self = (0..<size).map {
296+
let start = _handles.index(
297+
_handles.startIndex, offsetBy: $0 * Int(Element._tensorHandleCount))
298+
let end = _handles.index(
299+
start, offsetBy: Int(Element._tensorHandleCount))
300+
return Element.init(_handles: _handles[start..<end])
301+
}
302+
}
225303
}

Sources/TensorFlow/Core/TensorHandle.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
6262
self.handle = TFETensorHandle(_owning: cTensorHandle)
6363
}
6464

65+
public init(handle: _AnyTensorHandle) {
66+
self.handle = handle
67+
}
68+
6569
@usableFromInline
6670
init(copyingFromCTensor cTensor: CTensor) {
6771
let status = TF_NewStatus()
@@ -145,6 +149,11 @@ public struct ResourceHandle {
145149
init(owning cTensorHandle: CTensorHandle) {
146150
self.handle = TFETensorHandle(_owning: cTensorHandle)
147151
}
152+
153+
@usableFromInline
154+
init(handle: _AnyTensorHandle) {
155+
self.handle = handle
156+
}
148157
}
149158

150159
public struct VariantHandle {
@@ -157,4 +166,9 @@ public struct VariantHandle {
157166
init(owning cTensorHandle: CTensorHandle) {
158167
self.handle = TFETensorHandle(_owning: cTensorHandle)
159168
}
169+
170+
@usableFromInline
171+
init(handle: _AnyTensorHandle) {
172+
self.handle = handle
173+
}
160174
}

Sources/TensorFlow/Initializers.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,12 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
376376
///
377377
init(
378378
randomUniform shape: TensorShape,
379-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
380-
Int64.random(in: Int64.min..<Int64.max))
379+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
380+
Int32.random(in: Int32.min..<Int32.max))
381381
) {
382382
self = Raw.statelessRandomUniform(
383383
shape: Tensor<Int32>((0..<shape.rank).map { Int32(shape[$0]) }),
384-
seed: Tensor<Int64>([seed.0, seed.1])
384+
seed: Tensor<Int32>([seed.0, seed.1])
385385
)
386386
}
387387

@@ -394,12 +394,12 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
394394
///
395395
init(
396396
randomNormal shape: TensorShape,
397-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
398-
Int64.random(in: Int64.min..<Int64.max))
397+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
398+
Int32.random(in: Int32.min..<Int32.max))
399399
) {
400400
self = Raw.statelessRandomNormal(
401401
shape: Tensor<Int32>((0..<shape.rank).map { Int32(shape[$0]) }),
402-
seed: Tensor<Int64>([seed.0, seed.1])
402+
seed: Tensor<Int32>([seed.0, seed.1])
403403
)
404404
}
405405
}
@@ -475,8 +475,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
475475
///
476476
init(
477477
glorotUniform shape: TensorShape,
478-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
479-
Int64.random(in: Int64.min..<Int64.max))
478+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
479+
Int32.random(in: Int32.min..<Int32.max))
480480
) {
481481
let uniform = Tensor(randomUniform: shape, seed: seed)
482482
self = Tensor.glorot(fromStandardUniform: uniform, shape: shape)

Sources/TensorFlow/Layer.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public protocol Layer: Differentiable & KeyPathIterable
3131
/// - Parameter input: The input to the layer.
3232
/// - Returns: The output.
3333
@differentiable
34-
func call(_ input: Input) -> Output
34+
func callAsFunction(_ input: Input) -> Output
3535
}
3636

3737
public extension Layer {

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
5959
/// - Parameter input: The input to the layer `[batchCount, width, inputChannels]`.
6060
/// - Returns: The output `[batchCount, newWidth, outputChannels]`.
6161
@differentiable
62-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
62+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
6363
let conv2D = input.expandingShape(at: 1).convolved2D(
6464
withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding)
6565
return activation(conv2D.squeezingShape(at: 1) + bias)
@@ -116,8 +116,8 @@ public extension Conv1D {
116116
stride: Int = 1,
117117
padding: Padding = .valid,
118118
activation: @escaping Activation = identity,
119-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
120-
Int64.random(in: Int64.min..<Int64.max))
119+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
120+
Int32.random(in: Int32.min..<Int32.max))
121121
) {
122122
let filterTensorShape = TensorShape([
123123
filterShape.0, filterShape.1, filterShape.2])
@@ -177,7 +177,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
177177
/// - Parameter input: The input to the layer.
178178
/// - Returns: The output.
179179
@differentiable
180-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
180+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
181181
return activation(input.convolved2D(withFilter: filter,
182182
strides: (1, strides.0, strides.1, 1),
183183
padding: padding) + bias)
@@ -232,8 +232,8 @@ public extension Conv2D {
232232
strides: (Int, Int) = (1, 1),
233233
padding: Padding = .valid,
234234
activation: @escaping Activation = identity,
235-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
236-
Int64.random(in: Int64.min..<Int64.max))
235+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
236+
Int32.random(in: Int32.min..<Int32.max))
237237
) {
238238
let filterTensorShape = TensorShape([
239239
filterShape.0, filterShape.1, filterShape.2, filterShape.3])
@@ -293,7 +293,7 @@ public struct Conv3D<Scalar: TensorFlowFloatingPoint>: Layer {
293293
/// - Parameter input: The input to the layer.
294294
/// - Returns: The output.
295295
@differentiable
296-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
296+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
297297
return activation(input.convolved3D(withFilter: filter,
298298
strides: (1, strides.0, strides.1, strides.2, 1),
299299
padding: padding) + bias)
@@ -348,8 +348,8 @@ public extension Conv3D {
348348
strides: (Int, Int, Int) = (1, 1, 1),
349349
padding: Padding = .valid,
350350
activation: @escaping Activation = identity,
351-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
352-
Int64.random(in: Int64.min..<Int64.max))
351+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
352+
Int32.random(in: Int32.min..<Int32.max))
353353
) {
354354
let filterTensorShape = TensorShape([
355355
filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
@@ -411,7 +411,7 @@ public struct TransposedConv2D: Layer {
411411
/// - Parameter input: The input to the layer.
412412
/// - Returns: The output.
413413
@differentiable
414-
public func call(_ input: Tensor<Float>) -> Tensor<Float> {
414+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
415415
let batchSize = input.shape[0]
416416
let w = (input.shape[1] - (1 * paddingIndex)) *
417417
strides.0 + (filter.shape[0] * paddingIndex)
@@ -473,8 +473,8 @@ public extension TransposedConv2D {
473473
strides: (Int, Int) = (1, 1),
474474
padding: Padding = .valid,
475475
activation: @escaping Activation = identity,
476-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
477-
Int64.random(in: Int64.min..<Int64.max))
476+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
477+
Int32.random(in: Int32.min..<Int32.max))
478478
) {
479479
let filterTensorShape = TensorShape([
480480
filterShape.0, filterShape.1, filterShape.2, filterShape.3])

Sources/TensorFlow/Layers/Core.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: Layer {
5353
/// - Parameter input: The input to the layer.
5454
/// - Returns: The output.
5555
@differentiable(vjp: _vjpApplied(to:))
56-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
56+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
5757
switch Context.local.learningPhase {
5858
case .training:
5959
return applyingTraining(to: input)
@@ -92,7 +92,7 @@ public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
9292
/// - Parameter input: The input to the layer.
9393
/// - Returns: The output.
9494
@differentiable
95-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
95+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
9696
let batchSize = input.shape[0]
9797
let remaining = input.shape[1..<input.rank].contiguousSize
9898
return input.reshaped(to: [batchSize, remaining])
@@ -128,7 +128,7 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
128128
/// - Parameter input: The input to the layer.
129129
/// - Returns: The output.
130130
@differentiable
131-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
131+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
132132
return input.reshaped(toShape: shape)
133133
}
134134
}
@@ -163,7 +163,7 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
163163
/// - Parameter input: The input to the layer.
164164
/// - Returns: The output.
165165
@differentiable
166-
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
166+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
167167
return activation(matmul(input, weight) + bias)
168168
}
169169
}
@@ -214,8 +214,8 @@ public extension Dense {
214214
inputSize: Int,
215215
outputSize: Int,
216216
activation: @escaping Activation = identity,
217-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
218-
Int64.random(in: Int64.min..<Int64.max))
217+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
218+
Int32.random(in: Int32.min..<Int32.max))
219219
) {
220220
self.init(weight: Tensor(glorotUniform: [inputSize, outputSize],
221221
seed: seed),

0 commit comments

Comments (0)