Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit e5a288b

Browse files
authored
Move Dataset.swift over from TensorFlow. (#133)
1 parent f111ede commit e5a288b

File tree

3 files changed

+373
-0
lines changed

3 files changed

+373
-0
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
//===-- Dataset.swift -----------------------------------------*- swift -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// The dataset API.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
/// The graph-level seed used when no other graph seed is available.
///
/// - Note: See TensorFlow's `python.framework.random_seed.DEFAULT_GRAPH_SEED`.
@usableFromInline
let _defaultGraphSeed: Int64 = 87_654_321
21+
22+
/// Derives the pair of seeds an operation should use from its op-specific
/// seed.
///
/// Many random operations take two seeds so that randomness can be changed
/// either globally for a graph or only for specific operations. The first
/// element of the returned pair is the graph-level seed; the second is `seed`,
/// passed through unchanged.
///
/// - Note: See TensorFlow's `python.framework.random_seed.get_seed`.
@usableFromInline @inline(__always)
func _tensorSeeds(_ seed: Tensor<Int64>) -> (Tensor<Int64>, Tensor<Int64>) {
    // TODO: There's no support for TF's "global seed" yet, so we always use the
    // default graph seed as the first seed. Need to investigate the best way to
    // model TF's "global seed".
    let graphSeed = Tensor<Int64>(_defaultGraphSeed)
    return (graphSeed, seed)
}
38+
39+
//===----------------------------------------------------------------------===//
40+
// Single value dataset
41+
//===----------------------------------------------------------------------===//
42+
43+
/// A potentially large set of elements.
///
/// A `Dataset` models an input pipeline as a collection of element tensors.
/// The value itself only stores the handle of the underlying dataset.
@_fixed_layout
public struct Dataset<Element : TensorGroup> {
    /// The variant handle that identifies the underlying dataset.
    public let _handle: VariantHandle

    /// Wraps an existing dataset handle.
    ///
    /// - Parameter _handle: The handle of the dataset to wrap.
    @inlinable
    public init(_handle handle: VariantHandle) {
        _handle = handle
    }
}
56+
57+
public extension Dataset {
    /// Creates a dataset from the experimental random dataset op.
    ///
    /// - Parameter randomSeed: The op-level seed. It is paired with the
    ///   default graph seed before being handed to the op.
    @inlinable
    init(randomSeed: Int64) {
        let seeds = _tensorSeeds(Tensor(randomSeed))
        let handle: VariantHandle = Raw.experimentalRandomDataset(
            seed: seeds.0,
            seed2: seeds.1,
            outputTypes: Element._typeList,
            outputShapes: Element._unknownShapeList)
        self.init(_handle: handle)
    }
}
68+
69+
public extension Dataset {
    /// Creates a dataset from a batch of elements given as a tensor group.
    ///
    /// - Parameter elements: The batch to build the dataset from. It is passed
    ///   as the single component of a tensor-slice dataset.
    @inlinable
    init(elements: Element) {
        let components = [elements]
        let handle: VariantHandle = Raw.tensorSliceDataset(
            components: components,
            outputShapes: Element._unknownShapeList)
        self.init(_handle: handle)
    }
}
78+
79+
extension Dataset : Sequence {
    public typealias Iterator = DatasetIterator<Element>

    /// Returns an iterator over the elements of this dataset.
    ///
    /// Each call creates a new anonymous iterator resource and binds it to
    /// this dataset's handle before wrapping it.
    @inlinable
    public func makeIterator() -> DatasetIterator<Element> {
        let iteratorHandle = Raw.anonymousIterator(
            outputTypes: Element._typeList,
            outputShapes: Element._unknownShapeList)
        Raw.makeIterator(dataset: _handle, iterator: iteratorHandle)
        return DatasetIterator(_handle: iteratorHandle)
    }
}
92+
93+
public extension Dataset {
    // Note that this Dataset API implementation uses an experimental tracing
    // feature, which is not robust and does not have great diagnostics yet.

    /// Returns a dataset built by the map op from this dataset and
    /// `transform`.
    ///
    /// - Parameter transform: The function handed to the op as `f`. It is
    ///   traced, so it should stay within tensor operations.
    @inlinable
    func map<ResultElement : TensorGroup>(
        _ transform: (Element) -> ResultElement
    ) -> Dataset<ResultElement> {
        let mapped: VariantHandle = Raw.mapDataset(
            inputDataset: _handle,
            otherArguments: Tensor<Int32>(0),
            f: transform,
            outputTypes: ResultElement._typeList,
            outputShapes: ResultElement._unknownShapeList,
            useInterOpParallelism: true,
            preserveCardinality: false)
        return Dataset<ResultElement>(_handle: mapped)
    }

    /// Returns a dataset built by the parallel map op from this dataset and
    /// `transform`.
    ///
    /// - Parameters:
    ///   - parallelCallCount: Forwarded to the op as `numParallelCalls` after
    ///     conversion to `Int32`; the conversion traps if the value does not
    ///     fit.
    ///   - transform: The function handed to the op as `f`. It is traced, so
    ///     it should stay within tensor operations.
    @inlinable
    func map<ResultElement : TensorGroup>(
        parallelCallCount: Int,
        _ transform: (Element) -> ResultElement
    ) -> Dataset<ResultElement> {
        let callCount = Tensor<Int32>(Int32(parallelCallCount))
        let mapped: VariantHandle = Raw.parallelMapDataset(
            inputDataset: _handle,
            otherArguments: Tensor<Int32>(0),
            numParallelCalls: callCount,
            f: transform,
            outputTypes: ResultElement._typeList,
            outputShapes: ResultElement._unknownShapeList,
            useInterOpParallelism: true,
            sloppy: false,
            preserveCardinality: false)
        return Dataset<ResultElement>(_handle: mapped)
    }

    /// Returns a dataset built by the filter op from this dataset and
    /// `isIncluded`.
    ///
    /// - Parameter isIncluded: The predicate handed to the op. It returns a
    ///   boolean tensor and is traced, so it should avoid conversions to host
    ///   data types.
    @inlinable
    func filter(
        _ isIncluded: (Element) -> Tensor<Bool>
    ) -> Dataset {
        let filtered: VariantHandle = Raw.filterDataset(
            inputDataset: _handle,
            otherArguments: Tensor<Int32>(0),
            predicate: isIncluded,
            outputTypes: Element._typeList,
            outputShapes: Element._unknownShapeList)
        return Dataset(_handle: filtered)
    }
}
139+
140+
public extension Dataset {
    /// Returns a dataset built by the shuffle op from this dataset.
    ///
    /// - Parameters:
    ///   - sampleCount: Forwarded to the op as its `bufferSize`.
    ///   - randomSeed: The op-level seed. It is paired with the default graph
    ///     seed before being handed to the op.
    @inlinable
    func shuffled(
        sampleCount: Int, randomSeed: Int64
    ) -> Dataset {
        let seeds = _tensorSeeds(Tensor(randomSeed))
        let bufferSize = Tensor<Int64>(Int64(sampleCount))
        let shuffledHandle: VariantHandle = Raw.shuffleDataset(
            inputDataset: _handle,
            bufferSize: bufferSize,
            seed: seeds.0,
            seed2: seeds.1,
            outputTypes: Element._typeList,
            outputShapes: Element._unknownShapeList)
        return Dataset(_handle: shuffledHandle)
    }

    /// Returns a dataset built by the batch op from this dataset.
    ///
    /// - Parameter batchSize: Forwarded to the op as its `batchSize`.
    @inlinable
    func batched(_ batchSize: Int) -> Dataset {
        let size = Tensor<Int64>(Int64(batchSize))
        let batchedHandle: VariantHandle = Raw.batchDataset(
            inputDataset: _handle,
            batchSize: size,
            outputTypes: Element._typeList,
            outputShapes: Element._unknownShapeList)
        return Dataset(_handle: batchedHandle)
    }
}
164+
165+
/// The iterator type produced by `Dataset.makeIterator()`.
///
/// It stores only the handle of the underlying iterator resource.
@_fixed_layout
public struct DatasetIterator<Element : TensorGroup> {
    /// The resource handle of the underlying iterator.
    @usableFromInline let _handle: ResourceHandle

    /// Wraps an existing iterator resource handle.
    @usableFromInline
    internal init(_handle handle: ResourceHandle) {
        _handle = handle
    }
}
175+
176+
extension DatasetIterator : IteratorProtocol {
    /// Advances to the next element and returns it, or `nil` if no next
    /// element exists.
    ///
    /// The element is fetched as an optional variant first, so running past
    /// the end is reported as `nil`.
    @inlinable
    public mutating func next() -> Element? {
        let fetched = Raw.iteratorGetNextAsOptional(
            iterator: _handle,
            outputTypes: Element._typeList,
            outputShapes: Element._unknownShapeList)
        let hasValue = Raw.optionalHasValue(optional: fetched).scalarized()
        if !hasValue {
            return nil
        }
        let element: Element = Raw.optionalGetValue(
            optional: fetched,
            outputShapes: Element._unknownShapeList)
        return element
    }
}
193+
194+
/// A pair of `TensorGroup` values that itself conforms to `TensorGroup`.
///
/// It stands in for a 2-tuple where a `TensorGroup` is required, for example
/// as the element type of the dataset returned by `zip(_:_:)`.
@_fixed_layout
public struct Zip2TensorGroup<T : TensorGroup, U : TensorGroup> : TensorGroup {
    /// The first value of the pair.
    public var first: T
    /// The second value of the pair.
    public var second: U

    /// Creates a pair from its two values.
    public init(_ firstValue: T, _ secondValue: U) {
        first = firstValue
        second = secondValue
    }
}
206+
207+
/// Returns a dataset built by the zip op from the two given datasets.
///
/// - Parameters:
///   - dataset1: The dataset whose handle is passed first to the op.
///   - dataset2: The dataset whose handle is passed second to the op.
/// - Returns: A dataset whose element type is `Zip2TensorGroup<T, U>`.
@inlinable
public func zip<T : TensorGroup, U : TensorGroup>(
    _ dataset1: Dataset<T>, _ dataset2: Dataset<U>
) -> Dataset<Zip2TensorGroup<T, U>> {
    let inputHandles = [dataset1._handle, dataset2._handle]
    return Dataset(_handle: Raw.zipDataset(
        inputDatasets: inputHandles,
        outputTypes: Zip2TensorGroup<T, U>._typeList,
        outputShapes: Zip2TensorGroup<T, U>._unknownShapeList))
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
import DeepLearning
17+
18+
/// A two-component tensor group used to receive the output of
/// `Raw.iteratorGetNext` in `testMultiValue`.
struct SimpleOutput : TensorGroup {
    /// Handle of the first component.
    let a: TensorHandle<Int32>
    /// Handle of the second component.
    let b: TensorHandle<Int32>
}
22+
23+
/// Tests for `Dataset`, `DatasetIterator`, and the raw dataset ops they wrap.
final class DatasetTests: XCTestCase {
    /// Slices two tensors in lockstep through the raw ops and checks every
    /// element pair in order.
    func testMultiValue() {
        let elements1: Tensor<Int32> = [0, 1, 2]
        let elements2: Tensor<Int32> = [10, 11, 12]
        let outputTypes = [Int32.tensorFlowDataType, Int32.tensorFlowDataType]
        let outputShapes: [TensorShape?] = [nil, nil]
        let dataset: VariantHandle = Raw.tensorSliceDataset(
            components: [elements1, elements2],
            outputShapes: outputShapes
        )
        let iterator: ResourceHandle = Raw.iteratorV2(sharedName: "blah",
            container: "earth", outputTypes: outputTypes, outputShapes: outputShapes
        )
        Raw.makeIterator(dataset: dataset, iterator: iterator)
        let expectedPairs: [(Int32, Int32)] = [(0, 10), (1, 11), (2, 12)]
        for (expectedA, expectedB) in expectedPairs {
            let next: SimpleOutput = Raw.iteratorGetNext(
                iterator: iterator, outputShapes: outputShapes
            )
            XCTAssertEqual(expectedA, Tensor(handle: next.a).scalarized())
            XCTAssertEqual(expectedB, Tensor(handle: next.b).scalarized())
        }
    }

    /// Drives a `DatasetIterator` by hand and compares each element with the
    /// matching slice of the source tensor.
    func testSingleValueManualIterator() {
        // [[0], [1], [2], [3], [4]]
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
            .reshaped(to: [5, 1])
        let dataset = Dataset(elements: scalars)
        var iterator = dataset.makeIterator()
        var i: Int = 0
        while let item = iterator.next() {
            XCTAssertEqual(scalars[i].array, item.array)
            i += 1
        }
        // Without this check the test passes vacuously when the iterator
        // yields fewer elements than the source tensor holds.
        XCTAssertEqual(5, i)
    }

    /// Iterates a dataset with `for`-`in` through its `Sequence` conformance.
    func testDatasetIteration() {
        // [[0], [1], [2], [3], [4]]
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
            .reshaped(to: [5, 1])
        let dataset = Dataset(elements: scalars)
        var i: Int = 0
        for item in dataset {
            XCTAssertEqual(scalars[i].array, item.array)
            i += 1
        }
        // Without this check the test passes vacuously when the dataset
        // yields fewer elements than the source tensor holds.
        XCTAssertEqual(5, i)
    }

    /// Checks the element order produced by `shuffled` for a fixed seed.
    func testSingleValueTransformations() {
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
        let dataset = Dataset(elements: scalars)
        let shuffled = dataset.shuffled(sampleCount: 5, randomSeed: 42)
        XCTAssertEqual([0, 4, 1, 3, 2], shuffled.map { $0.scalar! })
    }

    /// Checks `map` and `filter` on a single-tensor dataset.
    func testSingleValueHOFs() {
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
        let dataset = Dataset(elements: scalars)
        let addedOne: Dataset = dataset.map { $0 + 1 }
        XCTAssertEqual([1, 2, 3, 4, 5], addedOne.flatMap { $0.scalars })
        // Use '.==' in the following closure to avoid any conversions to
        // host data types, which is not handled correctly in tracing.
        let evens: Dataset = dataset.filter { Tensor($0 % 2) .== Tensor(0) }
        XCTAssertEqual([0, 2, 4], evens.flatMap { $0.scalars })
    }

    /// Checks `map(parallelCallCount:_:)`, followed by a `filter` on the
    /// original dataset.
    func testParallelMap() {
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
        let dataset = Dataset(elements: scalars)
        let addedOne: Dataset = dataset.map(parallelCallCount: 5) { $0 + 1 }
        XCTAssertEqual([1, 2, 3, 4, 5], addedOne.flatMap { $0.scalars })
        // Use '.==' in the following closure to avoid any conversions to
        // host data types, which is not handled correctly in tracing.
        let evens: Dataset = dataset.filter { Tensor($0 % 2) .== Tensor(0) }
        XCTAssertEqual([0, 2, 4], evens.flatMap { $0.scalars })
    }

    /// Checks that `map` can change the element type (`Float` to `Bool`).
    func testMapToDifferentType() {
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
        let dataset = Dataset(elements: scalars)
        let shuffled = dataset.shuffled(sampleCount: 5, randomSeed: 42)
        XCTAssertEqual([0, 4, 1, 3, 2], shuffled.map { $0.scalar! })
        let evens = shuffled.map { Tensor($0 % 2) .== Tensor(0) }
        XCTAssertEqual([true, true, false, false, true], evens.map { $0.scalar! })
    }

    /// Checks `batched`, including the shorter final batch and exhaustion.
    func testSingleValueBatched() {
        let scalars = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
        let dataset = Dataset(elements: scalars)
        let batched = dataset.batched(2)

        var iterator = batched.makeIterator()
        XCTAssertEqual([0, 1], iterator.next()!.scalars)
        XCTAssertEqual([2, 3], iterator.next()!.scalars)
        XCTAssertEqual([4], iterator.next()!.scalars)
        // Five elements in batches of two leave nothing after the third batch.
        XCTAssertNil(iterator.next())
    }

    /*
    func testDoubleValueDatasetIteration() {
        let scalars1 = Tensor<Float>(rangeFrom: 0, to: 5, stride: 1)
        let scalars2 = Tensor<Int32>(rangeFrom: 5, to: 10, stride: 1)
        let datasetLeft = Dataset(elements: scalars1)
        let datasetRight = Dataset(elements: scalars2)
        var i: Int = 0
        for pair in zip(datasetLeft, datasetRight) {
            XCTAssertEqual(scalars1[i].array, pair.first.array)
            XCTAssertEqual(scalars2[i].array, pair.second.array)
            i += 1
        }
    }
    */

    /// Test list for platforms where XCTest cannot discover tests itself.
    static var allTests = [
        ("testMultiValue", testMultiValue),
        ("testSingleValueManualIterator", testSingleValueManualIterator),
        ("testDatasetIteration", testDatasetIteration),
        ("testSingleValueTransformations", testSingleValueTransformations),
        ("testSingleValueHOFs", testSingleValueHOFs),
        ("testParallelMap", testParallelMap),
        ("testMapToDifferentType", testMapToDifferentType),
        ("testSingleValueBatched", testSingleValueBatched),
        // Currently broken even in TensorFlow ...
        // This will be easier to fix once everything is moved ...
        // ("testDoubleValueDatasetIteration", testDoubleValueDatasetIteration),
    ]
}

Tests/DeepLearningTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public func allTests() -> [XCTestCaseEntry] {
2323
testCase(SequentialTests.allTests),
2424
testCase(LayerTests.allTests),
2525
testCase(TensorTests.allTests),
26+
testCase(DatasetTests.allTests),
2627
]
2728
}
2829
#endif

0 commit comments

Comments
 (0)