Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b698f1f

Browse files
authored
Enhanced the random seed generation function. (#328)
1 parent 2aed9be commit b698f1f

File tree

4 files changed

+205
-2
lines changed

4 files changed

+205
-2
lines changed

Sources/TensorFlow/Core/Utilities.swift

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,152 @@ extension UnsafeMutablePointer where Pointee == CTensorHandle? {
167167
self.init(unwrapped)
168168
}
169169
}
170+
171+
//===------------------------------------------------------------------------------------------===//
172+
// Hashing
173+
//===------------------------------------------------------------------------------------------===//
174+
175+
internal extension FixedWidthInteger {
176+
init(bytes: ArraySlice<UInt8>, startingAt index: Int) {
177+
if bytes.isEmpty { self.init(0); return }
178+
let count = bytes.count
179+
self.init(0)
180+
for i in 0..<MemoryLayout<Self>.size {
181+
let j = (MemoryLayout<Self>.size - i - 1) * 8
182+
self |= count > 0 ? Self(bytes[index.advanced(by: i)]) << j : 0
183+
}
184+
}
185+
186+
func bytes(count byteCount: Int = MemoryLayout<Self>.size) -> [UInt8] {
187+
var littleEndianValue = littleEndian
188+
return withUnsafePointer(to: &littleEndianValue) { pointer -> [UInt8] in
189+
let bytesPointer = UnsafeMutablePointer<UInt8>(OpaquePointer(pointer))
190+
var bytes = [UInt8](repeating: 0, count: byteCount)
191+
for i in 0..<Swift.min(MemoryLayout<Self>.size, byteCount) {
192+
bytes[byteCount - 1 - i] = (bytesPointer + i).pointee
193+
}
194+
return bytes
195+
}
196+
}
197+
}
198+
199+
internal extension Array where Element == UInt8 {
200+
func sha512() -> SIMD64<UInt8> {
201+
// First we define some useful constants.
202+
let blockSize = 128
203+
let k: [UInt64] = [
204+
0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f, 0xe9b5dba58189dbbc,
205+
0x3956c25bf348b538, 0x59f111f1b605d019, 0x923f82a4af194f9b, 0xab1c5ed5da6d8118,
206+
0xd807aa98a3030242, 0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
207+
0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235, 0xc19bf174cf692694,
208+
0xe49b69c19ef14ad2, 0xefbe4786384f25e3, 0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65,
209+
0x2de92c6f592b0275, 0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
210+
0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f, 0xbf597fc7beef0ee4,
211+
0xc6e00bf33da88fc2, 0xd5a79147930aa725, 0x06ca6351e003826f, 0x142929670a0e6e70,
212+
0x27b70a8546d22ffc, 0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
213+
0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6, 0x92722c851482353b,
214+
0xa2bfe8a14cf10364, 0xa81a664bbc423001, 0xc24b8b70d0f89791, 0xc76c51a30654be30,
215+
0xd192e819d6ef5218, 0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
216+
0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99, 0x34b0bcb5e19b48a8,
217+
0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb, 0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3,
218+
0x748f82ee5defb2fc, 0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
219+
0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915, 0xc67178f2e372532b,
220+
0xca273eceea26619c, 0xd186b8c721c0c207, 0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178,
221+
0x06f067aa72176fba, 0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
222+
0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc, 0x431d67c49c100d4c,
223+
0x4cc5d4becb3e42b6, 0x597f299cfc657e2a, 0x5fcb6fab3ad6faec, 0x6c44198c4a475817]
224+
225+
var accumulated = self
226+
let lengthInBits = accumulated.count * 8
227+
let lengthBytes = lengthInBits.bytes(count: blockSize / 8)
228+
229+
// Step 1: Append padding.
230+
let msgLength = accumulated.count
231+
// Append one bit (`UInt8` with one bit) to the message.
232+
accumulated.append(0x80)
233+
// Append `0` bits until the length of `accumulated` in bits is 448 (mod 512).
234+
let max = blockSize * 7 / 8
235+
accumulated += [UInt8](
236+
repeating: 0,
237+
count: msgLength % blockSize < max ?
238+
max - 1 - (msgLength % blockSize) :
239+
blockSize + max - 1 - (msgLength % blockSize))
240+
241+
// Step 2: Append the message length as a 64-bit representation of `lengthInBits`.
242+
accumulated += lengthBytes
243+
244+
// Step 3: Process the array bytes.
245+
var accumulatedHash = SIMD8<UInt64>(
246+
0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1,
247+
0x510e527fade682d1, 0x9b05688c2b3e6c1f, 0x1f83d9abfb41bd6b, 0x5be0cd19137e2179)
248+
var index = 0
249+
while index < accumulated.count {
250+
let chunk = accumulated[index..<(index + blockSize)]
251+
index += blockSize
252+
253+
// Break chunk into sixteen 64-bit words w[j], 0 ≤ j ≤ 15, in big-endian format.
254+
// Extend the sixteen 64-bit words into eighty 64-bit words:
255+
var w = [UInt64](repeating: 0, count: k.count)
256+
for x in k.indices {
257+
switch x {
258+
case 0...15:
259+
let start = chunk.startIndex.advanced(by: x * 8)
260+
w[x] = UInt64(bytes: chunk, startingAt: start)
261+
break
262+
default:
263+
let s0Term0 = ((w[x - 15] >> 1 ^ w[x - 15]) >> 6 ^ w[x - 15]) >> 1
264+
let s0Term1 = (w[x - 15] << 7 ^ w[x - 15]) << 56
265+
let s0 = s0Term0 ^ s0Term1
266+
let s1Term0 = ((w[x - 2] >> 42 ^ w[x - 2]) >> 13 ^ w[x - 2]) >> 6
267+
let s1Term1 = (w[x - 2] << 42 ^ w[x - 2]) << 3
268+
let s1 = s1Term0 ^ s1Term1
269+
w[x] = w[x - 16] &+ s0 &+ w[x - 7] &+ s1
270+
break
271+
}
272+
}
273+
274+
var hashCopy = accumulatedHash
275+
for j in k.indices {
276+
let s0Term0 = ((hashCopy[0] >> 5 ^ hashCopy[0]) >> 6 ^ hashCopy[0]) >> 28
277+
let s0Term1 = ((hashCopy[0] << 6 ^ hashCopy[0]) << 5 ^ hashCopy[0]) << 25
278+
let s0 = s0Term0 ^ s0Term1
279+
let s1Term0 = ((hashCopy[4] >> 23 ^ hashCopy[4]) >> 4 ^ hashCopy[4]) >> 14
280+
let s1Term1 = ((hashCopy[4] << 4 ^ hashCopy[4]) << 23 ^ hashCopy[4]) << 23
281+
let s1 = s1Term0 ^ s1Term1
282+
let maj = (hashCopy[0] & hashCopy[1]) ^
283+
(hashCopy[0] & hashCopy[2]) ^
284+
(hashCopy[1] & hashCopy[2])
285+
let t2 = s0 &+ maj
286+
let ch = (hashCopy[4] & hashCopy[5]) ^ (~hashCopy[4] & hashCopy[6])
287+
let t1 = hashCopy[7] &+ s1 &+ ch &+ k[j] &+ w[j]
288+
hashCopy[7] = hashCopy[6]
289+
hashCopy[6] = hashCopy[5]
290+
hashCopy[5] = hashCopy[4]
291+
hashCopy[4] = hashCopy[3] &+ t1
292+
hashCopy[3] = hashCopy[2]
293+
hashCopy[2] = hashCopy[1]
294+
hashCopy[1] = hashCopy[0]
295+
hashCopy[0] = t1 &+ t2
296+
}
297+
accumulatedHash &+= hashCopy
298+
}
299+
300+
// Step 4: Return the computed hash.
301+
var result = SIMD64<UInt8>()
302+
var position = 0
303+
for index in accumulatedHash.indices {
304+
let h = accumulatedHash[index]
305+
result[position + 0] = UInt8((h >> 56) & 0xff)
306+
result[position + 1] = UInt8((h >> 48) & 0xff)
307+
result[position + 2] = UInt8((h >> 40) & 0xff)
308+
result[position + 3] = UInt8((h >> 32) & 0xff)
309+
result[position + 4] = UInt8((h >> 24) & 0xff)
310+
result[position + 5] = UInt8((h >> 16) & 0xff)
311+
result[position + 6] = UInt8((h >> 8) & 0xff)
312+
result[position + 7] = UInt8(h & 0xff)
313+
position += 8
314+
}
315+
316+
return result
317+
}
318+
}

Sources/TensorFlow/Random.swift

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,31 @@ import Glibc
1919
#endif
2020

2121
/// Generates a new random seed for TensorFlow.
22-
public func randomSeedForTensorFlow() -> (Int32, Int32) {
23-
(Int32.random(in: Int32.min..<Int32.max), Int32.random(in: Int32.min..<Int32.max))
22+
public func randomSeedForTensorFlow(using seed: (Int32, Int32)? = nil) -> (Int32, Int32) {
23+
var strongSeed = UInt64(0)
24+
if let s = seed {
25+
let bytes = (s.0.bytes() + s.1.bytes())[...]
26+
let singleSeed = UInt64(bytes: bytes, startingAt: bytes.startIndex)
27+
strongSeed = UInt64(pow(Double(singleSeed % 2), Double(8 * 8)))
28+
} else {
29+
strongSeed = UInt64.random(in: UInt64.min..<UInt64.max)
30+
}
31+
32+
// Many machine learning systems are likely to have many random number generators active at
33+
// once (e.g., in reinforcement learning we may have an environment running in multiple
34+
// processes). There is literature indicating that having linear correlations between seeds of
35+
// multiple PRNG's can correlate the outputs:
36+
// - http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers
37+
// - http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
38+
// - http://dl.acm.org/citation.cfm?id=1276928
39+
// Thus, for sanity we hash the generated seed before using it, This scheme is likely not
40+
// crypto-strength, but it should be good enough to get rid of simple correlations.
41+
// Reference: https://github.com/openai/gym/blob/master/gym/utils/seeding.py
42+
43+
let hash = strongSeed.bytes().sha512()
44+
let first = Int32(bytes: [hash[0], hash[1], hash[2], hash[3]], startingAt: 0)
45+
let second = Int32(bytes: [hash[4], hash[5], hash[6], hash[7]], startingAt: 0)
46+
return (first, second)
2447
}
2548

2649
//===------------------------------------------------------------------------------------------===//
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import XCTest
16+
@testable import TensorFlow
17+
18+
final class UtilitiesTests: XCTestCase {
19+
func testSHA512() {
20+
XCTAssertEqual(
21+
[UInt8](repeating: 0x61, count: 1000).sha512(),
22+
SIMD64<UInt8>([
23+
103, 186, 85, 53, 164, 110, 63, 134, 219, 251, 237, 140, 187, 175, 1, 37,
24+
199, 110, 213, 73, 255, 139, 11, 158, 3, 224, 200, 140, 249, 15, 166, 52,
25+
250, 123, 18, 180, 125, 119, 182, 148, 222, 72, 138, 206, 141, 154, 101, 150,
26+
125, 201, 109, 245, 153, 114, 125, 50, 146, 168, 217, 212, 71, 112, 156, 151]))
27+
}
28+
29+
static var allTests = [("testSHA512", testSHA512)]
30+
}

Tests/TensorFlowTests/XCTestManifests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import XCTest
1717
#if !os(macOS)
1818
public func allTests() -> [XCTestCaseEntry] {
1919
return [
20+
testCase(UtilitiesTests.allTests),
2021
testCase(LossTests.allTests),
2122
testCase(PRNGTests.allTests),
2223
testCase(TrivialModelTests.allTests),

0 commit comments

Comments
 (0)