Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6d525f5

Browse files
authored
Use multiple TFE_OpAddInput calls instead of one TFE_OpAddIputList (#375)
1 parent 87ba3c4 commit 6d525f5

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

Sources/TensorFlow/Bindings/EagerExecution.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ internal struct TFE_Op: TFTensorOperation {
7070
defer { buffer.deallocate() }
7171
let pointer = UnsafeMutablePointer<OpaquePointer?>(buffer.baseAddress)
7272
input._unpackTensorHandles(into: buffer.baseAddress)
73-
TFE_OpAddInputList(op, pointer, count, status)
74-
// TODO: checkOk(status)
73+
for i in 0..<Int(count) {
74+
TFE_OpAddInput(op, buffer[i], status)
75+
checkOk(status)
76+
}
7577
}
7678

7779
@inlinable @inline(__always)

Sources/TensorFlow/Bindings/EagerExecution.swift.gyb

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ internal struct TFE_Op: TFTensorOperation {
7070
defer { buffer.deallocate() }
7171
let pointer = UnsafeMutablePointer<OpaquePointer?>(buffer.baseAddress)
7272
input._unpackTensorHandles(into: buffer.baseAddress)
73-
TFE_OpAddInputList(op, pointer, count, status)
74-
// TODO: checkOk(status)
73+
for i in 0..<Int(count) {
74+
TFE_OpAddInput(op, buffer[i], status)
75+
checkOk(status)
76+
}
7577
}
7678

7779
@inlinable @inline(__always)
@@ -271,7 +273,12 @@ internal struct TFE_Op: TFTensorOperation {
271273
_ name: String,
272274
_ value: (In) -> Out
273275
) {
274-
_tffunc(value).utf8CString.withUnsafeBufferPointer { buffer in
276+
updateAttribute(name, _TensorFunctionPointer(name: _tffunc(value)))
277+
}
278+
279+
@inlinable @inline(__always)
280+
internal func updateAttribute(_ name: String, _ value: _TensorFunctionPointer) {
281+
value.name.utf8CString.withUnsafeBufferPointer { buffer in
275282
// utf8CString is null-terminated; TFE_OpSetAttrFunctionName wants
276283
// non-null-terminated.
277284
TFE_OpSetAttrFunctionName(op, name, buffer.baseAddress, buffer.count - 1)

Tests/TensorFlowTests/OperatorTests/DatasetTests.swift

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ final class DatasetTests: XCTestCase {
125125
XCTAssertEqual(iterator.next()!.scalars, [4])
126126
}
127127

128-
/*
129128
func testDoubleValueDatasetIteration() {
130129
let scalars1 = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
131130
let scalars2 = Tensor<Int32>(rangeFrom: 5, to: 10, stride: 1)
@@ -138,7 +137,6 @@ final class DatasetTests: XCTestCase {
138137
i += 1
139138
}
140139
}
141-
*/
142140

143141
static var allTests = [
144142
("testMultiValue", testMultiValue),
@@ -149,8 +147,6 @@ final class DatasetTests: XCTestCase {
149147
("testParallelMap", testParallelMap),
150148
("testMapToDifferentType", testMapToDifferentType),
151149
("testSingleValueBatched", testSingleValueBatched),
152-
// Currently broken even in TensorFlow ...
153-
// This will be easier to fix once everything is moved ...
154-
// ("testDoubleValueDatasetIteration", testDoubleValueDatasetIteration),
150+
("testDoubleValueDatasetIteration", testDoubleValueDatasetIteration),
155151
]
156152
}

0 commit comments

Comments
 (0)