Skip to content

Commit ce4dfd3

Browse files
committed
Changes to support the new swift-bindings.
1 parent df0ec40 commit ce4dfd3

File tree

6 files changed

+33
-15
lines changed

6 files changed

+33
-15
lines changed

stdlib/public/TensorFlow/StringOps.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ public extension StringTensor {
2323
/// - Note: `elementsEqual` supports broadcasting.
2424
@inlinable @inline(__always)
2525
func elementsEqual(_ other: StringTensor) -> Tensor<Bool> {
26-
return #tfop("Equal", self.handle, other.handle,
27-
T$dtype: String.tensorFlowDataType)
26+
return Raw.equal(self, other)
2827
}
2928
}

stdlib/public/TensorFlow/Tensor.swift

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -416,12 +416,7 @@ extension _TensorElementLiteral : ExpressibleByArrayLiteral {
416416
public typealias ArrayLiteralElement = _TensorElementLiteral<Scalar>
417417
@inlinable @inline(__always)
418418
public init(arrayLiteral elements: _TensorElementLiteral<Scalar>...) {
419-
// Attr T (non-optional in the op definition) need not be specified when we
420-
// run the op as part of a graph function, but need to be specified when we
421-
// run it via eager C API.
422-
let handle: TensorHandle<Scalar> = #tfop("Pack", elements,
423-
T$dtype: Scalar.tensorFlowDataType)
424-
tensor = Tensor(handle: handle)
419+
tensor = Raw.pack(elements.map { $0.tensor })
425420
}
426421
}
427422

@@ -436,8 +431,7 @@ extension Tensor : ExpressibleByArrayLiteral {
436431
internal init(
437432
_tensorElementLiterals elements: [_TensorElementLiteral<Scalar>]
438433
) {
439-
self.init(handle: #tfop("Pack", elements,
440-
T$dtype: Scalar.tensorFlowDataType))
434+
self = Raw.pack(elements.map { $0.tensor })
441435
}
442436

443437
/// Creates a tensor initialized with the given elements.

stdlib/public/TensorFlow/TensorGroup.swift

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ public protocol TensorArrayProtocol {
2929
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?)
3030

3131
var _tensorHandleCount: Int32 { get }
32+
33+
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
3234
}
3335

3436
/// A protocol representing types that can be mapped to and from
@@ -67,6 +69,11 @@ public extension TensorGroup {
6769
static var _unknownShapeList: [TensorShape?] {
6870
return Array(repeating: nil, count: _typeList.count)
6971
}
72+
73+
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
74+
precondition(count == Self._typeList.count)
75+
self.init(_owning: tensorHandles)
76+
}
7077
}
7178

7279
//===----------------------------------------------------------------------===//
@@ -199,7 +206,7 @@ extension StringTensor : TensorGroup {
199206
}
200207
}
201208

202-
extension Array : TensorArrayProtocol where Element : TensorArrayProtocol {
209+
extension Array : TensorArrayProtocol where Element : TensorGroup {
203210
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
204211
var ptr = address
205212
for elem in self {
@@ -213,4 +220,11 @@ extension Array : TensorArrayProtocol where Element : TensorArrayProtocol {
213220
for elem in self { count += elem._tensorHandleCount }
214221
return count
215222
}
223+
224+
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
225+
let size = count / Int(Element._tensorHandleCount)
226+
self = Array((0..<size).map { Element.init(
227+
_owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount)))
228+
})
229+
}
216230
}

test/TensorFlowRuntime/tracer.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,25 @@ TracerTests.testAllBackends("Advanced") {
168168
var model: Model = [Tensor<Float>(1.0), Tensor<Float>(2.0)]
169169
var optimizer: Optimizer = [Tensor<Float>(1.0), Tensor<Float>(2.0)]
170170

171+
public init() {}
172+
173+
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
174+
self.model = [
175+
Tensor<Float>(_owning: tensorHandles),
176+
Tensor<Float>(_owning: tensorHandles?.advanced(by: 1))]
177+
self.optimizer = [
178+
Tensor<Float>(_owning: tensorHandles?.advanced(by: 2)),
179+
Tensor<Float>(_owning: tensorHandles?.advanced(by: 3))]
180+
}
181+
171182
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
172183
print("Calling State._unpackTensorHandles().")
173184
var ptr = address
174185
model._unpackTensorHandles(into: ptr)
175186
ptr = ptr!.advanced(by: Int(model._tensorHandleCount))
176187
optimizer._unpackTensorHandles(into: ptr)
177188
}
189+
178190
public var _tensorHandleCount: Int32 {
179191
return model._tensorHandleCount + optimizer._tensorHandleCount
180192
}

utils/build-script-impl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ KNOWN_SETTINGS=(
280280
tensorflow-host-include-dir "" "Path to host TensorFlow headers"
281281
tensorflow-target-include-dir "" "Path to target Tensorflow headers"
282282
tensorflow-target-lib-dir "" "Path to target TensorFlow libraries"
283-
tensorflow-swift-bindings "" "Path to TensorFlow Swift bindings file"
283+
tensorflow-swift-bindings "" "Path to TensorFlow Swift bindings repository"
284284
tensorflow-swift-apis "" "Path to TensorFlow deep learning library repository"
285285
)
286286

@@ -2476,7 +2476,7 @@ for host in "${ALL_HOSTS[@]}"; do
24762476

24772477
# Handle TensorFlow Swift bindings file.
24782478
if [[ ! "${TENSORFLOW_SWIFT_BINDINGS}" && -d "${TENSORFLOW_SWIFT_BINDINGS_DIR}" ]] ; then
2479-
TENSORFLOW_SWIFT_BINDINGS="${TENSORFLOW_SWIFT_BINDINGS_DIR}/RawOpsGenerated.swift"
2479+
TENSORFLOW_SWIFT_BINDINGS="${TENSORFLOW_SWIFT_BINDINGS_DIR}"
24802480
fi
24812481
if [[ "${TENSORFLOW_SWIFT_BINDINGS}" ]] ; then
24822482
cmake_options=(

utils/build_swift/driver_arguments.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,7 @@ def create_argument_parser():
974974
'Used for linking Swift programs.')
975975
option('--tensorflow-swift-bindings', store_path,
976976
default=None,
977-
help='Path to a TensorFlow Swift bindings file '
978-
'(RawOpsGenerated.swift).')
977+
help='Path to a TensorFlow Swift bindings repository.')
979978
option('--tensorflow-swift-apis', store_path,
980979
default=None,
981980
help='Path to a TensorFlow deep learning library repository.')

0 commit comments

Comments
 (0)