Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 30a1071

Browse files
authored
[Layering]: Create Tensor library & move Random. (#540)
This is the first of a series of changes to implement the [Layering Swift APIs Proposal](https://docs.google.com/document/d/1HO_sMhZJHxlDqw4Pjz4qva2s5yzgwxel-82yoK8O6L4/edit#). This first change creates the new Tensor library and moves the random number utilities over.
1 parent 5b845e3 commit 30a1071

File tree

15 files changed

+351
-257
lines changed

15 files changed

+351
-257
lines changed

Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ RUN /swift-tensorflow-toolchain/usr/bin/swift test
6262
# TODO: Unify this with testing. (currently there is a demangling bug).
6363
RUN /swift-tensorflow-toolchain/usr/bin/swift build -Xswiftc -module-link-name -Xswiftc TensorFlow
6464
RUN cp /swift-apis/.build/debug/TensorFlow.swiftmodule /swift-tensorflow-toolchain/usr/lib/swift/linux/x86_64/
65+
RUN cp /swift-apis/.build/debug/Tensor.swiftmodule /swift-tensorflow-toolchain/usr/lib/swift/linux/x86_64/
6566
RUN cp /swift-apis/.build/debug/libTensorFlow.so /swift-tensorflow-toolchain/usr/lib/swift/linux/
67+
RUN cp /swift-apis/.build/debug/libTensor.so /swift-tensorflow-toolchain/usr/lib/swift/linux/
6668

6769
WORKDIR /
6870
RUN git clone https://github.com/tensorflow/swift-models.git

Package.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,29 @@ let package = Package(
2424
name: "TensorFlow",
2525
type: .dynamic,
2626
targets: ["TensorFlow"]),
27+
.library(
28+
name: "Tensor",
29+
type: .dynamic,
30+
targets: ["Tensor"]),
2731
],
2832
dependencies: [],
2933
targets: [
3034
.target(
31-
name: "TensorFlow",
35+
name: "Tensor",
3236
dependencies: []),
37+
.target(
38+
name: "TensorFlow",
39+
dependencies: ["Tensor"]),
3340
.target(
3441
name: "Experimental",
3542
dependencies: [],
3643
path: "Sources/third_party/Experimental"),
3744
.testTarget(
3845
name: "ExperimentalTests",
3946
dependencies: ["Experimental"]),
47+
.testTarget(
48+
name: "TensorTests",
49+
dependencies: ["Tensor"]),
4050
.testTarget(
4151
name: "TensorFlowTests",
4252
dependencies: ["TensorFlow"]),

Sources/TensorFlow/Random.swift renamed to Sources/Tensor/Random.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,18 @@ public func randomSeedForTensorFlow(using seed: TensorFlowSeed? = nil) -> Tensor
5656
///
5757
/// The `AnyRandomNumberGenerator` type forwards random number generating operations to an
5858
/// underlying random number generator, hiding its specific underlying type.
59-
internal struct AnyRandomNumberGenerator: RandomNumberGenerator {
59+
public struct AnyRandomNumberGenerator: RandomNumberGenerator {
6060
@usableFromInline
6161
var _rng: RandomNumberGenerator
6262

6363
/// - Parameter rng: A random number generator.
6464
@inlinable
65-
init(_ rng: RandomNumberGenerator) {
65+
public init(_ rng: RandomNumberGenerator) {
6666
self._rng = rng
6767
}
6868

6969
@inlinable
70-
mutating func next() -> UInt64 {
70+
public mutating func next() -> UInt64 {
7171
return self._rng.next()
7272
}
7373
}
@@ -400,7 +400,7 @@ public struct PhiloxRandomNumberGenerator: SeedableRandomNumberGenerator {
400400
return ctr
401401
}
402402

403-
internal init(uint64Seed seed: UInt64) {
403+
public init(uint64Seed seed: UInt64) {
404404
key = seed.vector2
405405
}
406406

Sources/Tensor/Utilities.swift

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
16+
import Darwin
17+
#else
18+
import Glibc
19+
#endif
20+
21+
//===------------------------------------------------------------------------------------------===//
22+
// Hashing
23+
//===------------------------------------------------------------------------------------------===//
24+
25+
internal extension FixedWidthInteger {
26+
init(bytes: ArraySlice<UInt8>, startingAt index: Int) {
27+
if bytes.isEmpty { self.init(0); return }
28+
let count = bytes.count
29+
self.init(0)
30+
for i in 0..<MemoryLayout<Self>.size {
31+
let j = (MemoryLayout<Self>.size - i - 1) * 8
32+
self |= count > 0 ? Self(bytes[index.advanced(by: i)]) << j : 0
33+
}
34+
}
35+
36+
func bytes(count byteCount: Int = MemoryLayout<Self>.size) -> [UInt8] {
37+
let actualByteCount = Swift.min(MemoryLayout<Self>.size, byteCount)
38+
var littleEndianValue = littleEndian
39+
return withUnsafePointer(to: &littleEndianValue) {
40+
$0.withMemoryRebound(to: UInt8.self, capacity: actualByteCount) { pointer in
41+
var bytes = [UInt8](repeating: 0, count: byteCount)
42+
for i in 0..<actualByteCount {
43+
bytes[byteCount - 1 - i] = (pointer + i).pointee
44+
}
45+
return bytes
46+
}
47+
}
48+
}
49+
}
50+
51+
internal extension Array where Element == UInt8 {
52+
/// - Note: The SHA1 hash is only 20 bytes long and so only the first 20 bytes of the returned
53+
/// `SIMD32<UInt8>` are non-zero.
54+
func sha1() -> SIMD32<UInt8> {
55+
let blockSize = 64
56+
var accumulated = self
57+
let lengthInBits = accumulated.count * 8
58+
let lengthBytes = lengthInBits.bytes(count: blockSize / 8)
59+
60+
// Step 1: Append padding.
61+
let msgLength = accumulated.count
62+
// Append one bit (`UInt8` with one bit) to the message.
63+
accumulated.append(0x80)
64+
// Append `0` bits until the length of `accumulated` in bits is 448 (mod 512).
65+
let max = blockSize * 7 / 8
66+
accumulated += [UInt8](
67+
repeating: 0,
68+
count: msgLength % blockSize < max ?
69+
max - 1 - (msgLength % blockSize) :
70+
blockSize + max - 1 - (msgLength % blockSize))
71+
72+
// Step 2: Append the message length as a 64-bit representation of `lengthInBits`.
73+
accumulated += lengthBytes
74+
75+
// Step 3: Process the array bytes.
76+
var accumulatedHash = SIMD8<UInt32>([
77+
0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0, 0x00, 0x00, 0x00])
78+
var index = 0
79+
while index < accumulated.count {
80+
let chunk = accumulated[index..<(index + blockSize)]
81+
index += blockSize
82+
83+
// Break chunk into sixteen 32-bit words w[j], 0 ≤ j ≤ 15, in big-endian format.
84+
// Extend the sixteen 32-bit words into eighty 32-bit words:
85+
var w = [UInt32](repeating: 0, count: 80)
86+
for x in w.indices {
87+
switch x {
88+
case 0...15:
89+
let start = chunk.startIndex.advanced(by: x * 4)
90+
w[x] = UInt32(bytes: chunk, startingAt: start)
91+
break
92+
default:
93+
let term = w[x - 3] ^ w[x - 8] ^ w[x - 14] ^ w[x - 16]
94+
w[x] = term << 1 ^ term >> 31
95+
break
96+
}
97+
}
98+
99+
var hashCopy = accumulatedHash
100+
for j in w.indices {
101+
var f: UInt32 = 0
102+
var k: UInt32 = 0
103+
switch j {
104+
case 0...19:
105+
f = (hashCopy[1] & hashCopy[2]) | (~hashCopy[1] & hashCopy[3])
106+
k = 0x5a827999
107+
break
108+
case 20...39:
109+
f = hashCopy[1] ^ hashCopy[2] ^ hashCopy[3]
110+
k = 0x6ed9eba1
111+
break
112+
case 40...59:
113+
f = (hashCopy[1] & hashCopy[2]) |
114+
(hashCopy[1] & hashCopy[3]) |
115+
(hashCopy[2] & hashCopy[3])
116+
k = 0x8f1bbcdc
117+
break
118+
default:
119+
f = hashCopy[1] ^ hashCopy[2] ^ hashCopy[3]
120+
k = 0xca62c1d6
121+
break
122+
}
123+
let temp = hashCopy[0] << 5 ^ hashCopy[0] >> 27
124+
let t0 = temp &+ f &+ hashCopy[4] &+ w[j] &+ k
125+
hashCopy[4] = hashCopy[3]
126+
hashCopy[3] = hashCopy[2]
127+
hashCopy[2] = hashCopy[1] << 30 ^ hashCopy[1] >> 2
128+
hashCopy[1] = hashCopy[0]
129+
hashCopy[0] = t0
130+
}
131+
accumulatedHash &+= hashCopy
132+
}
133+
134+
// Step 4: Return the computed hash.
135+
var result = SIMD32<UInt8>()
136+
var position = 0
137+
for index in accumulatedHash.indices {
138+
let h = accumulatedHash[index]
139+
result[position + 0] = UInt8((h >> 24) & 0xff)
140+
result[position + 1] = UInt8((h >> 16) & 0xff)
141+
result[position + 2] = UInt8((h >> 8) & 0xff)
142+
result[position + 3] = UInt8(h & 0xff)
143+
position += 4
144+
}
145+
146+
return result
147+
}
148+
149+
func sha512() -> SIMD64<UInt8> {
150+
// First we define some useful constants.
151+
let blockSize = 128
152+
let k: [UInt64] = [
153+
0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f, 0xe9b5dba58189dbbc,
154+
0x3956c25bf348b538, 0x59f111f1b605d019, 0x923f82a4af194f9b, 0xab1c5ed5da6d8118,
155+
0xd807aa98a3030242, 0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
156+
0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235, 0xc19bf174cf692694,
157+
0xe49b69c19ef14ad2, 0xefbe4786384f25e3, 0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65,
158+
0x2de92c6f592b0275, 0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
159+
0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f, 0xbf597fc7beef0ee4,
160+
0xc6e00bf33da88fc2, 0xd5a79147930aa725, 0x06ca6351e003826f, 0x142929670a0e6e70,
161+
0x27b70a8546d22ffc, 0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
162+
0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6, 0x92722c851482353b,
163+
0xa2bfe8a14cf10364, 0xa81a664bbc423001, 0xc24b8b70d0f89791, 0xc76c51a30654be30,
164+
0xd192e819d6ef5218, 0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
165+
0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99, 0x34b0bcb5e19b48a8,
166+
0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb, 0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3,
167+
0x748f82ee5defb2fc, 0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
168+
0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915, 0xc67178f2e372532b,
169+
0xca273eceea26619c, 0xd186b8c721c0c207, 0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178,
170+
0x06f067aa72176fba, 0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
171+
0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc, 0x431d67c49c100d4c,
172+
0x4cc5d4becb3e42b6, 0x597f299cfc657e2a, 0x5fcb6fab3ad6faec, 0x6c44198c4a475817]
173+
174+
var accumulated = self
175+
let lengthInBits = accumulated.count * 8
176+
let lengthBytes = lengthInBits.bytes(count: blockSize / 8)
177+
178+
// Step 1: Append padding.
179+
let msgLength = accumulated.count
180+
// Append one bit (`UInt8` with one bit) to the message.
181+
accumulated.append(0x80)
182+
// Append `0` bits until the length of `accumulated` in bits is 448 (mod 512).
183+
let max = blockSize * 7 / 8
184+
accumulated += [UInt8](
185+
repeating: 0,
186+
count: msgLength % blockSize < max ?
187+
max - 1 - (msgLength % blockSize) :
188+
blockSize + max - 1 - (msgLength % blockSize))
189+
190+
// Step 2: Append the message length as a 64-bit representation of `lengthInBits`.
191+
accumulated += lengthBytes
192+
193+
// Step 3: Process the array bytes.
194+
var accumulatedHash = SIMD8<UInt64>(
195+
0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1,
196+
0x510e527fade682d1, 0x9b05688c2b3e6c1f, 0x1f83d9abfb41bd6b, 0x5be0cd19137e2179)
197+
var index = 0
198+
while index < accumulated.count {
199+
let chunk = accumulated[index..<(index + blockSize)]
200+
index += blockSize
201+
202+
// Break chunk into sixteen 64-bit words w[j], 0 ≤ j ≤ 15, in big-endian format.
203+
// Extend the sixteen 64-bit words into eighty 64-bit words:
204+
var w = [UInt64](repeating: 0, count: k.count)
205+
for x in w.indices {
206+
switch x {
207+
case 0...15:
208+
let start = chunk.startIndex.advanced(by: x * 8)
209+
w[x] = UInt64(bytes: chunk, startingAt: start)
210+
break
211+
default:
212+
let s0Term0 = ((w[x - 15] >> 1 ^ w[x - 15]) >> 6 ^ w[x - 15]) >> 1
213+
let s0Term1 = (w[x - 15] << 7 ^ w[x - 15]) << 56
214+
let s0 = s0Term0 ^ s0Term1
215+
let s1Term0 = ((w[x - 2] >> 42 ^ w[x - 2]) >> 13 ^ w[x - 2]) >> 6
216+
let s1Term1 = (w[x - 2] << 42 ^ w[x - 2]) << 3
217+
let s1 = s1Term0 ^ s1Term1
218+
w[x] = w[x - 16] &+ s0 &+ w[x - 7] &+ s1
219+
break
220+
}
221+
}
222+
223+
var hashCopy = accumulatedHash
224+
for j in w.indices {
225+
let s0Term0 = ((hashCopy[0] >> 5 ^ hashCopy[0]) >> 6 ^ hashCopy[0]) >> 28
226+
let s0Term1 = ((hashCopy[0] << 6 ^ hashCopy[0]) << 5 ^ hashCopy[0]) << 25
227+
let s0 = s0Term0 ^ s0Term1
228+
let s1Term0 = ((hashCopy[4] >> 23 ^ hashCopy[4]) >> 4 ^ hashCopy[4]) >> 14
229+
let s1Term1 = ((hashCopy[4] << 4 ^ hashCopy[4]) << 23 ^ hashCopy[4]) << 23
230+
let s1 = s1Term0 ^ s1Term1
231+
let maj = (hashCopy[0] & hashCopy[1]) ^
232+
(hashCopy[0] & hashCopy[2]) ^
233+
(hashCopy[1] & hashCopy[2])
234+
let t2 = s0 &+ maj
235+
let ch = (hashCopy[4] & hashCopy[5]) ^ (~hashCopy[4] & hashCopy[6])
236+
let t1 = hashCopy[7] &+ s1 &+ ch &+ k[j] &+ w[j]
237+
hashCopy[7] = hashCopy[6]
238+
hashCopy[6] = hashCopy[5]
239+
hashCopy[5] = hashCopy[4]
240+
hashCopy[4] = hashCopy[3] &+ t1
241+
hashCopy[3] = hashCopy[2]
242+
hashCopy[2] = hashCopy[1]
243+
hashCopy[1] = hashCopy[0]
244+
hashCopy[0] = t1 &+ t2
245+
}
246+
accumulatedHash &+= hashCopy
247+
}
248+
249+
// Step 4: Return the computed hash.
250+
var result = SIMD64<UInt8>()
251+
var position = 0
252+
for index in accumulatedHash.indices {
253+
let h = accumulatedHash[index]
254+
result[position + 0] = UInt8((h >> 56) & 0xff)
255+
result[position + 1] = UInt8((h >> 48) & 0xff)
256+
result[position + 2] = UInt8((h >> 40) & 0xff)
257+
result[position + 3] = UInt8((h >> 32) & 0xff)
258+
result[position + 4] = UInt8((h >> 24) & 0xff)
259+
result[position + 5] = UInt8((h >> 16) & 0xff)
260+
result[position + 6] = UInt8((h >> 8) & 0xff)
261+
result[position + 7] = UInt8(h & 0xff)
262+
position += 8
263+
}
264+
265+
return result
266+
}
267+
}

Sources/TensorFlow/Context.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import Darwin
1818
import Glibc
1919
#endif
2020

21+
import Tensor
22+
2123
/// A value that indicates the phase of using a machine learning model.
2224
public enum LearningPhase {
2325
case training

0 commit comments

Comments
 (0)