Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 23cdfd3

Browse files
authored
Make the seed type Int32 so that it is compatible with TPUs. (#169)
Otherwise, we get the following error: ``` (OpKernel was found, but attributes didn't match) Requested Attributes: T=DT_INT32, Tseed=DT_INT64, dtype=DT_FLOAT . Registered: device='TPU'; Tseed in [DT_INT32]; T in [DT_INT32, DT_INT64]; dtype in [DT_FLOAT, DT_BFLOAT16] ```
1 parent 738b7a5 commit 23cdfd3

File tree

7 files changed

+30
-30
lines changed

7 files changed

+30
-30
lines changed

Sources/TensorFlow/Initializers.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,12 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
376376
///
377377
init(
378378
randomUniform shape: TensorShape,
379-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
380-
Int64.random(in: Int64.min..<Int64.max))
379+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
380+
Int32.random(in: Int32.min..<Int32.max))
381381
) {
382382
self = Raw.statelessRandomUniform(
383383
shape: Tensor<Int32>((0..<shape.rank).map { Int32(shape[$0]) }),
384-
seed: Tensor<Int64>([seed.0, seed.1])
384+
seed: Tensor<Int32>([seed.0, seed.1])
385385
)
386386
}
387387

@@ -394,12 +394,12 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
394394
///
395395
init(
396396
randomNormal shape: TensorShape,
397-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
398-
Int64.random(in: Int64.min..<Int64.max))
397+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
398+
Int32.random(in: Int32.min..<Int32.max))
399399
) {
400400
self = Raw.statelessRandomNormal(
401401
shape: Tensor<Int32>((0..<shape.rank).map { Int32(shape[$0]) }),
402-
seed: Tensor<Int64>([seed.0, seed.1])
402+
seed: Tensor<Int32>([seed.0, seed.1])
403403
)
404404
}
405405
}
@@ -475,8 +475,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
475475
///
476476
init(
477477
glorotUniform shape: TensorShape,
478-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
479-
Int64.random(in: Int64.min..<Int64.max))
478+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
479+
Int32.random(in: Int32.min..<Int32.max))
480480
) {
481481
let uniform = Tensor(randomUniform: shape, seed: seed)
482482
self = Tensor.glorot(fromStandardUniform: uniform, shape: shape)

Sources/TensorFlow/Layers/Convolutional.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ public extension Conv1D {
116116
stride: Int = 1,
117117
padding: Padding = .valid,
118118
activation: @escaping Activation = identity,
119-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
120-
Int64.random(in: Int64.min..<Int64.max))
119+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
120+
Int32.random(in: Int32.min..<Int32.max))
121121
) {
122122
let filterTensorShape = TensorShape([
123123
filterShape.0, filterShape.1, filterShape.2])
@@ -232,8 +232,8 @@ public extension Conv2D {
232232
strides: (Int, Int) = (1, 1),
233233
padding: Padding = .valid,
234234
activation: @escaping Activation = identity,
235-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
236-
Int64.random(in: Int64.min..<Int64.max))
235+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
236+
Int32.random(in: Int32.min..<Int32.max))
237237
) {
238238
let filterTensorShape = TensorShape([
239239
filterShape.0, filterShape.1, filterShape.2, filterShape.3])
@@ -348,8 +348,8 @@ public extension Conv3D {
348348
strides: (Int, Int, Int) = (1, 1, 1),
349349
padding: Padding = .valid,
350350
activation: @escaping Activation = identity,
351-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
352-
Int64.random(in: Int64.min..<Int64.max))
351+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
352+
Int32.random(in: Int32.min..<Int32.max))
353353
) {
354354
let filterTensorShape = TensorShape([
355355
filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4])
@@ -473,8 +473,8 @@ public extension TransposedConv2D {
473473
strides: (Int, Int) = (1, 1),
474474
padding: Padding = .valid,
475475
activation: @escaping Activation = identity,
476-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
477-
Int64.random(in: Int64.min..<Int64.max))
476+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
477+
Int32.random(in: Int32.min..<Int32.max))
478478
) {
479479
let filterTensorShape = TensorShape([
480480
filterShape.0, filterShape.1, filterShape.2, filterShape.3])

Sources/TensorFlow/Layers/Core.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ public extension Dense {
214214
inputSize: Int,
215215
outputSize: Int,
216216
activation: @escaping Activation = identity,
217-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
218-
Int64.random(in: Int64.min..<Int64.max))
217+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
218+
Int32.random(in: Int32.min..<Int32.max))
219219
) {
220220
self.init(weight: Tensor(glorotUniform: [inputSize, outputSize],
221221
seed: seed),

Sources/TensorFlow/Layers/Recurrent.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ public struct SimpleRNNCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNum
101101
/// - hiddenSize: The number of features in 2-D hidden states.
102102
/// - seed: The random seed for initialization. The default value is random.
103103
public init(inputSize: Int, hiddenSize: Int,
104-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
105-
Int64.random(in: Int64.min..<Int64.max))) {
104+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
105+
Int32.random(in: Int32.min..<Int32.max))) {
106106
let concatenatedInputSize = inputSize + hiddenSize
107107
self.weight = Tensor(glorotUniform: [concatenatedInputSize, hiddenSize], seed: seed)
108108
self.bias = Tensor(zeros: [hiddenSize])
@@ -144,8 +144,8 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell, VectorNumeric
144144
/// - inputSize: The number of features in 2-D input tensors.
145145
/// - hiddenSize: The number of features in 2-D hidden states.
146146
public init(inputSize: Int, hiddenSize: Int,
147-
seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
148-
Int64.random(in: Int64.min..<Int64.max))) {
147+
seed: (Int32, Int32) = (Int32.random(in: Int32.min..<Int32.max),
148+
Int32.random(in: Int32.min..<Int32.max))) {
149149
let concatenatedInputSize = inputSize + hiddenSize
150150
let gateWeightShape = TensorShape([concatenatedInputSize, hiddenSize])
151151
let gateBiasShape = TensorShape([hiddenSize])

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,15 @@ final class LayerTests: XCTestCase {
214214
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
215215
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
216216
let rnn = RNN(SimpleRNNCell<Float>(inputSize: 4, hiddenSize: 4,
217-
seed: (0xFeedBeef, 0xDeadBeef)))
217+
seed: (0xFeed, 0xBeef)))
218218
let (outputs, _) = rnn.valueWithPullback(at: inputs) { rnn, inputs in
219219
return rnn(inputs)
220220
}
221221
XCTAssertEqual(outputs.map { $0.value },
222-
[[[ -0.00262943, -0.005866742, 0.044919778, 0.20036437]],
223-
[[ 0.066890605, 0.049586136, 0.024610005, 0.09341654]],
224-
[[ 0.065792546, 0.009325638, 0.06439907, 0.114802904]],
225-
[[ 0.055909205, 0.00035158166, 0.054020774, 0.09812111]]])
222+
[[[ 0.20775771, 0.20080023, -0.13768704, -0.18534681]],
223+
[[ 0.22666009, 0.30019346, -0.19720285, -0.14683801]],
224+
[[ 0.23758979, 0.32101023, -0.20359215, -0.1787096]],
225+
[[ 0.24337786, 0.3389194, -0.21143384, -0.1675081]]])
226226
// TODO: Figure out why the following is numerically unstable.
227227
// let (𝛁rnn, _) = pullback(.init(inputs.map { SimpleRNNCell<Float>.State($0) }))
228228
// XCTAssertEqual(𝛁rnn.cell.weight,

Tests/TensorFlowTests/SequentialTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ final class SequentialTests: XCTestCase {
2121
var dense1 = Dense<Float>(inputSize: 2, outputSize: 4, activation: relu,
2222
seed: (0xfffffff, 0xfeeff))
2323
var dense2 = Dense<Float>(inputSize: 4, outputSize: 1, activation: relu,
24-
seed: (0xfeffeffe, 0xfffe))
24+
seed: (0xeffeffe, 0xfffe))
2525

2626
@differentiable
2727
func call(_ input: Tensor<Float>) -> Tensor<Float> {
@@ -41,7 +41,7 @@ final class SequentialTests: XCTestCase {
4141
optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
4242
}
4343
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
44-
[[ 0.491493], [ 0.5063815], [0.49968663], [0.50133944]])
44+
[[ 0.4904838], [0.49942452], [0.49740878], [ 0.5106092]])
4545
}
4646

4747
static var allTests = [

Tests/TensorFlowTests/TrivialModelTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ final class TrivialModelTests: XCTestCase {
3030
inputSize: hiddenSize,
3131
outputSize: 1,
3232
activation: relu,
33-
seed: (0xfeffeffe, 0xfffe)
33+
seed: (0xffeffe, 0xfffe)
3434
)
3535
}
3636
@differentiable

0 commit comments

Comments (0)