Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 030153c

Browse files
authored
A few quick fixes to help unbreak swift-models. (#160)
1 parent ad63e2b commit 030153c

File tree

3 files changed

+8
-13
lines changed

3 files changed

+8
-13
lines changed

MiniGo/Models/PythonCheckpointReader.swift

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,9 @@ public class PythonCheckpointReader {
2828
let countSuffix = layerCounts[layerName] == nil ? "" : "_\(layerCounts[layerName]!)"
2929
let tensorName = layerName + countSuffix + "/" + weightName
3030
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
31-
return Tensor<Float>(handle: #tfop(
32-
"RestoreV2",
33-
StringTensor(path),
34-
StringTensor([tensorName]),
35-
StringTensor([""]),
36-
dtypes$dtype: [Float.tensorFlowDataType]))
31+
return Raw.restoreV2(prefix: StringTensor(path),
32+
tensorNames: StringTensor([tensorName]),
33+
shapeAndSlices: StringTensor([""]))
3734
}
3835

3936
/// Increments a per-layer counter for variable names in the checkpoint file.

Transformer/Operators.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ import TensorFlow
1919
/// Computes the Gaussian error linear unit (GELU) nonlinear activation function
2020
@differentiable
2121
func gelu<Scalar: TensorFlowFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
22-
let polynomial = 0.79788456 * (x + 0.044715 * x * x * x)
22+
let xCubed = x * x * x
23+
let polynomial = 0.79788456 * (x + 0.044715 * xCubed)
2324
return 0.5 * x * (1.0 + tanh(polynomial))
2425
}
2526

Transformer/PythonCheckpointReader.swift

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,9 @@ func readTensor<Scalar: TensorFlowScalar>(
3636
scalarType: Scalar.Type
3737
) -> Tensor<Scalar> {
3838
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
39-
return Tensor(handle: #tfop(
40-
"RestoreV2",
41-
StringTensor(path),
42-
StringTensor([name]),
43-
StringTensor([""]),
44-
dtypes$dtype: [Scalar.tensorFlowDataType]))
39+
return Raw.restoreV2(prefix: StringTensor(path),
40+
tensorNames: StringTensor([name]),
41+
shapeAndSlices: StringTensor([""]))
4542
}
4643

4744
protocol InitializableFromPythonCheckpoint {

0 commit comments

Comments
 (0)