Skip to content

Commit b8a41ae

Browse files
eaplataniospschuh
authored andcommitted
[TF] WIP: Removing all instances of #tfop. (#24425)
* Removed some instances of '#tfop'. * Removed all instances of '#tfop'. * Fixes to the build scripts for the new swift-bindings. * Disabled a couple of GPE tests. * Added back a couple of test helpers. * Minor tweak. * Minor tweaks. * Removed all uses of '.toHost' and '.toAccelerator'. * Added back some tests. * Bug fixes. * Updated the swift-bindings dependency.
1 parent 91fe4ee commit b8a41ae

20 files changed

+189
-630
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ Below is more information about TensorFlow-related build arguments.
151151
* Default value: None.
152152
* `tensorflow-swift-apis`: A path to the [tensorflow/swift-apis](https://github.com/tensorflow/swift-apis) deep learning library repository.
153153
* Default value: `tensorflow-swift-apis` if the [tensorflow/swift-apis](https://github.com/tensorflow/swift-apis) repository is cloned. Otherwise, none.
154-
* `tensorflow-swift-bindings`: A generated TensorFlow Swift bindings file (`RawOpsGenerated.swift`) obtained from [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings).
155-
* Default value: `tensorflow-swift-bindings/RawOpsGenerated.swift` if the [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings) repository is cloned. Otherwise, none.
154+
* `tensorflow-swift-bindings`: A path to the [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings) repository.
155+
* Default value: `tensorflow-swift-bindings` if the [tensorflow/swift-bindings](https://github.com/tensorflow/swift-bindings) repository is cloned. Otherwise, none.
156156

157157
### Build systems
158158

stdlib/public/TensorFlow/ArrayOps.swift

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// This file contains some Array ops that cannot be properly handled by #tfop.
13+
// This file contains some custom raw ops.
1414
//
1515
// TODO: These should be deleted once we can properly generate raw ops for these.
1616
//
@@ -109,91 +109,4 @@ public extension Raw {
109109
}
110110
return out
111111
}
112-
113-
/// Splits a tensor into `numSplit` tensors along one dimension.
114-
///
115-
/// - Parameters:
116-
/// - splitDim: 0-D. The dimension along which to split. Must be in the range
117-
/// `[-rank(value), rank(value))`.
118-
/// - value: The tensor to split.
119-
/// - numSplit: The number of splits to create.
120-
///
121-
/// - Returns: Tensors whose shape matches that of `value`
122-
/// except along `axis`, where their sizes are
123-
/// `value.shape[axis] / numSplit`.
124-
@inlinable @inline(__always)
125-
static func split<T: TensorFlowScalar>(
126-
splitDim: Tensor<Int32>,
127-
value: Tensor<T>,
128-
numSplit: Int64
129-
) -> [Tensor<T>] {
130-
let s: CTFStatus = TF_NewStatus()
131-
defer { TF_DeleteStatus(s) }
132-
let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "Split", s)
133-
defer { TFE_DeleteOp(op) }
134-
let _ = _TFCOpAddInputFromTensorGroup(op, splitDim, s)
135-
let _ = _TFCOpAddInputFromTensorGroup(op, value, s)
136-
TFE_OpSetAttrInt(op, "num_split", numSplit)
137-
TFE_OpSetAttrType(op, "T", T.tensorFlowDataType._cDataType)
138-
var count: Int32 = Int32(numSplit)
139-
let buffer: UnsafeMutablePointer<CTensorHandle> =
140-
UnsafeMutablePointer.allocate(capacity: Int(count))
141-
defer { buffer.deallocate() }
142-
_TFCEagerExecute(op, UnsafeMutablePointer<CTensorHandle?>(buffer), &count, s)
143-
checkOk(s)
144-
145-
var out: [Tensor<T>] = []
146-
var cursor = buffer
147-
for _ in 0..<numSplit {
148-
out.append(Tensor<T>(handle: TensorHandle(_owning: cursor.pointee)))
149-
cursor = cursor.advanced(by: 1)
150-
}
151-
return out
152-
}
153-
154-
/// Splits a tensor into `numSplit` tensors along one dimension.
155-
///
156-
/// - Parameters:
157-
/// - value: The tensor to split.
158-
/// - sizeSplits: list containing the sizes of each output tensor along the split
159-
/// dimension. Must sum to the dimension of value along split_dim.
160-
/// Can contain one -1 indicating that dimension is to be inferred.
161-
/// - splitDim: 0-D. The dimension along which to split. Must be in the range
162-
/// `[-rank(value), rank(value))`.
163-
///
164-
/// - Returns: Tensors whose shape matches that of `value`
165-
/// except along `axis`, where their sizes are
166-
/// `size_splits[i]`.
167-
@inlinable @inline(__always)
168-
static func splitV<T: TensorFlowScalar, Tlen: BinaryInteger & TensorFlowScalar>(
169-
value: Tensor<T>,
170-
sizeSplits: Tensor<Tlen>,
171-
splitDim: Tensor<Int32>,
172-
numSplit: Int64
173-
) -> [Tensor<T>] {
174-
let s: CTFStatus = TF_NewStatus()
175-
defer { TF_DeleteStatus(s) }
176-
let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "SplitV", s)
177-
defer { TFE_DeleteOp(op) }
178-
let _ = _TFCOpAddInputFromTensorGroup(op, value, s)
179-
let _ = _TFCOpAddInputFromTensorGroup(op, sizeSplits, s)
180-
let _ = _TFCOpAddInputFromTensorGroup(op, splitDim, s)
181-
TFE_OpSetAttrInt(op, "num_split", numSplit)
182-
TFE_OpSetAttrType(op, "T", T.tensorFlowDataType._cDataType)
183-
TFE_OpSetAttrType(op, "Tlen", Tlen.tensorFlowDataType._cDataType)
184-
var count: Int32 = Int32(numSplit)
185-
let buffer: UnsafeMutablePointer<CTensorHandle> =
186-
UnsafeMutablePointer.allocate(capacity: Int(count))
187-
defer { buffer.deallocate() }
188-
_TFCEagerExecute(op, UnsafeMutablePointer<CTensorHandle?>(buffer), &count, s)
189-
checkOk(s)
190-
191-
var out: [Tensor<T>] = []
192-
var cursor = buffer
193-
for _ in 0..<numSplit {
194-
out.append(Tensor<T>(handle: TensorHandle(_owning: cursor.pointee)))
195-
cursor = cursor.advanced(by: 1)
196-
}
197-
return out
198-
}
199112
}

stdlib/public/TensorFlow/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ set(SOURCES
5656

5757
# Copy TensorFlow bindings file, if it exists.
5858
if (TENSORFLOW_SWIFT_BINDINGS)
59-
list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS}")
59+
file(GLOB_RECURSE TENSORFLOW_SWIFT_BINDINGS_SOURCES
60+
"${TENSORFLOW_SWIFT_BINDINGS}/*.swift")
61+
list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS_SOURCES}")
6062
endif()
6163

6264
# Copy TensorFlow high-level API sources, if they exist.

stdlib/public/TensorFlow/Dataset.swift

Lines changed: 64 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -58,26 +58,21 @@ public extension Dataset {
5858
@inlinable @inline(__always)
5959
init(randomSeed: Int64) {
6060
let (seed1, seed2) = _tensorSeeds(Tensor(randomSeed))
61-
self.init(
62-
_handle: #tfop("RandomDataset", seed1, seed2,
63-
output_types$dtype: Element._typeList,
64-
output_shapes: Element._unknownShapeList)
65-
)
61+
self.init(_handle: Raw.experimentalRandomDataset(
62+
seed: seed1,
63+
seed2: seed2,
64+
outputTypes: Element._typeList,
65+
outputShapes: Element._unknownShapeList))
6666
}
6767
}
6868

6969
public extension Dataset {
7070
/// Creates a dataset from a batch of elements as a tensor.
7171
@inlinable @inline(__always)
7272
init(elements: Element) {
73-
// A dataset creation op only runs on TF CPU.
74-
self.init(
75-
_handle: #tfop(
76-
"TensorSliceDataset", [elements],
77-
Toutput_types$dtype: Element._typeList,
78-
output_shapes: Element._unknownShapeList
79-
)
80-
)
73+
self.init(_handle: Raw.tensorSliceDataset(
74+
components: [elements],
75+
outputShapes: Element._unknownShapeList))
8176
}
8277
}
8378

@@ -87,10 +82,10 @@ extension Dataset : Sequence {
8782
/// Returns an iterator over the elements of this dataset.
8883
@inlinable @inline(__always)
8984
public func makeIterator() -> DatasetIterator<Element> {
90-
let resource: ResourceHandle =
91-
#tfop("AnonymousIterator", output_types$dtype: Element._typeList,
92-
output_shapes: Element._unknownShapeList)
93-
#tfop("MakeIterator", _handle, resource) as Void
85+
let resource = Raw.anonymousIterator(
86+
outputTypes: Element._typeList,
87+
outputShapes: Element._unknownShapeList)
88+
Raw.makeIterator(dataset: _handle, iterator: resource)
9489
return DatasetIterator(_handle: resource)
9590
}
9691
}
@@ -102,44 +97,43 @@ public extension Dataset {
10297
func map<ResultElement : TensorGroup>(
10398
_ transform: (Element) -> ResultElement
10499
) -> Dataset<ResultElement> {
105-
return Dataset<ResultElement>(
106-
_handle: #tfop(
107-
"MapDataset", _handle, [Tensor<Int32>(0)],
108-
f$func: _tffunc(transform),
109-
Targuments$dtype: [Int32.tensorFlowDataType],
110-
output_types$dtype: ResultElement._typeList,
111-
output_shapes: ResultElement._unknownShapeList
112-
)
113-
)
100+
return Dataset<ResultElement>(_handle: Raw.mapDataset(
101+
inputDataset: _handle,
102+
otherArguments: Tensor<Int32>(0),
103+
f: transform,
104+
outputTypes: ResultElement._typeList,
105+
outputShapes: ResultElement._unknownShapeList,
106+
useInterOpParallelism: true,
107+
preserveCardinality: false))
114108
}
115109

116110
@inlinable @inline(__always)
117-
func map<ResultElement : TensorGroup>(parallelCallCount: Int,
118-
_ transform: (Element) -> ResultElement) -> Dataset<ResultElement> {
119-
return Dataset<ResultElement>(
120-
_handle: #tfop("ParallelMapDataset", _handle, [Tensor<Int32>(0)],
121-
[Tensor<Int32>(Int32(parallelCallCount))],
122-
f$func: _tffunc(transform),
123-
Targuments$dtype: [Int32.tensorFlowDataType],
124-
output_types$dtype: ResultElement._typeList,
125-
output_shapes: ResultElement._unknownShapeList
126-
)
127-
)
111+
func map<ResultElement : TensorGroup>(
112+
parallelCallCount: Int,
113+
_ transform: (Element) -> ResultElement
114+
) -> Dataset<ResultElement> {
115+
return Dataset<ResultElement>(_handle: Raw.parallelMapDataset(
116+
inputDataset: _handle,
117+
otherArguments: Tensor<Int32>(0),
118+
numParallelCalls: Tensor<Int32>(Int32(parallelCallCount)),
119+
f: transform,
120+
outputTypes: ResultElement._typeList,
121+
outputShapes: ResultElement._unknownShapeList,
122+
useInterOpParallelism: true,
123+
sloppy: false,
124+
preserveCardinality: false))
128125
}
129126

130127
@inlinable @inline(__always)
131128
func filter(
132129
_ isIncluded: (Element) -> Tensor<Bool>
133130
) -> Dataset {
134-
return Dataset(
135-
_handle: #tfop(
136-
"FilterDataset", _handle, [Tensor<Int32>(0)],
137-
predicate$func: _tffunc(isIncluded),
138-
Targuments$dtype: [Int32.tensorFlowDataType],
139-
output_types$dtype: Element._typeList,
140-
output_shapes: Element._unknownShapeList
141-
)
142-
)
131+
return Dataset(_handle: Raw.filterDataset(
132+
inputDataset: _handle,
133+
otherArguments: Tensor<Int32>(0),
134+
predicate: isIncluded,
135+
outputTypes: Element._typeList,
136+
outputShapes: Element._unknownShapeList))
143137
}
144138
}
145139

@@ -149,24 +143,22 @@ public extension Dataset {
149143
sampleCount: Int, randomSeed: Int64
150144
) -> Dataset {
151145
let (seed1, seed2) = _tensorSeeds(Tensor(randomSeed))
152-
return Dataset(
153-
_handle: #tfop(
154-
"ShuffleDataset", _handle, Tensor(Int64(sampleCount)), seed1, seed2,
155-
output_types$dtype: Element._typeList,
156-
output_shapes: Element._unknownShapeList
157-
)
158-
)
146+
return Dataset(_handle: Raw.shuffleDataset(
147+
inputDataset: _handle,
148+
bufferSize: Tensor(Int64(sampleCount)),
149+
seed: seed1,
150+
seed2: seed2,
151+
outputTypes: Element._typeList,
152+
outputShapes: Element._unknownShapeList))
159153
}
160154

161155
@inlinable @inline(__always)
162156
func batched(_ batchSize: Int) -> Dataset {
163-
return Dataset(
164-
_handle: #tfop(
165-
"BatchDataset", _handle, Tensor(Int64(batchSize)),
166-
output_types$dtype: Element._typeList,
167-
output_shapes: Element._unknownShapeList
168-
)
169-
)
157+
return Dataset(_handle: Raw.batchDataset(
158+
inputDataset: _handle,
159+
batchSize: Tensor(Int64(batchSize)),
160+
outputTypes: Element._typeList,
161+
outputShapes: Element._unknownShapeList))
170162
}
171163
}
172164

@@ -186,16 +178,16 @@ extension DatasetIterator : IteratorProtocol {
186178
/// exists.
187179
@inlinable @inline(__always)
188180
public mutating func next() -> Element? {
189-
let optional: VariantHandle =
190-
#tfop("IteratorGetNextAsOptional", _handle,
191-
output_types$dtype: Element._typeList,
192-
output_shapes: Element._unknownShapeList)
193-
guard _TFGetScalarOrDie(#tfop("OptionalHasValue", optional)) else {
181+
let optional = Raw.iteratorGetNextAsOptional(
182+
iterator: _handle,
183+
outputTypes: Element._typeList,
184+
outputShapes: Element._unknownShapeList)
185+
guard Raw.optionalHasValue(optional: optional).scalarized() else {
194186
return nil
195187
}
196-
return #tfop("OptionalGetValue", optional,
197-
output_types$dtype: Element._typeList,
198-
output_shapes: Element._unknownShapeList) as Element
188+
return Raw.optionalGetValue(
189+
optional: optional,
190+
outputShapes: Element._unknownShapeList)
199191
}
200192
}
201193

@@ -217,9 +209,9 @@ public struct Zip2TensorGroup<T : TensorGroup, U : TensorGroup> : TensorGroup {
217209
public func zip<T : TensorGroup, U : TensorGroup>(
218210
_ dataset1: Dataset<T>, _ dataset2: Dataset<U>
219211
) -> Dataset<Zip2TensorGroup<T, U>> {
220-
let handle: VariantHandle = #tfop(
221-
"ZipDataset", Zip2TensorGroup(dataset1._handle, dataset2._handle),
222-
output_types$dtype: Zip2TensorGroup<T, U>._typeList,
223-
output_shapes: Zip2TensorGroup<T, U>._unknownShapeList)
212+
let handle = Raw.zipDataset(
213+
inputDatasets: [dataset1._handle, dataset2._handle],
214+
outputTypes: Zip2TensorGroup<T, U>._typeList,
215+
outputShapes: Zip2TensorGroup<T, U>._unknownShapeList)
224216
return Dataset(_handle: handle)
225217
}

stdlib/public/TensorFlow/Ops.swift

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,6 @@
2525
// we also define some helper function wrappers, e.g. to make things symmetric
2626
// and generally feel nice to use.
2727
//
28-
// The ops themselves are defined by the primitive #tfop(...) syntax, here are
29-
// some examples:
30-
// result = #tfop("Add", lhs, rhs)
31-
// result = #tfop("Const", dtype: Float.self, value$tensor: 4.0)
32-
//
33-
// The first parameter to this syntax is the TensorFlow op name as a string.
34-
// After that, the inputs are specified, and then attributes are specified
35-
// with their name as the keyword argument.
36-
//
37-
// Inputs and outputs must be of TensorHandle, ResourceHandle, or VariantHandle
38-
// type. These are magic types known to the compiler.
39-
//
4028

4129
infix operator ++ : AdditionPrecedence
4230

@@ -700,17 +688,15 @@ internal extension Tensor where Scalar : TensorFlowFloatingPoint {
700688
@inlinable @inline(__always)
701689
func _vjpConcatenated(with other: Tensor, alongAxis axis: Int)
702690
-> (Tensor, (Tensor) -> (Tensor, Tensor)) {
703-
let idx = axis < 0 ? axis + rank : axis
704-
let splits = Tensor<Int32>([shapeTensor[idx], other.shapeTensor[idx]])
691+
let posAxis = axis < 0 ? axis + rank: axis
692+
let splits = Tensor<Int32>([shapeTensor[posAxis], other.shapeTensor[posAxis]])
705693
return (concatenated(with: other, alongAxis: axis), { result in
706-
let ret: (TensorHandle<Scalar>, TensorHandle<Scalar>) = #tfop("SplitV",
707-
result,
708-
splits,
709-
Tensor<Int32>(Int32(axis)),
710-
num_split: Int64(2),
711-
T$dtype: Scalar.tensorFlowDataType,
712-
Tlen$dtype: Int32.tensorFlowDataType)
713-
return (Tensor(handle: ret.0), Tensor(handle: ret.1))
694+
let gradients = Raw.splitV(
695+
value: result,
696+
sizeSplits: splits,
697+
splitDim: Tensor<Int32>(Int32(axis)),
698+
numSplit: Int64(splits.shape[0]))
699+
return (gradients[0], gradients[1])
714700
})
715701
}
716702
}

stdlib/public/TensorFlow/StringOps.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ public extension StringTensor {
2323
/// - Note: `elementsEqual` supports broadcasting.
2424
@inlinable @inline(__always)
2525
func elementsEqual(_ other: StringTensor) -> Tensor<Bool> {
26-
return #tfop("Equal", self.handle, other.handle,
27-
T$dtype: String.tensorFlowDataType)
26+
return Raw.equal(self, other)
2827
}
2928
}

0 commit comments

Comments
 (0)