This repository was archived by the owner on Apr 23, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 3 files changed +8
-13
lines changed Expand file tree Collapse file tree 3 files changed +8
-13
lines changed Original file line number Diff line number Diff line change @@ -28,12 +28,9 @@ public class PythonCheckpointReader {
28
28
let countSuffix = layerCounts [ layerName] == nil ? " " : " _ \( layerCounts [ layerName] !) "
29
29
let tensorName = layerName + countSuffix + " / " + weightName
30
30
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
31
- return Tensor < Float > ( handle: #tfop(
32
- " RestoreV2 " ,
33
- StringTensor ( path) ,
34
- StringTensor ( [ tensorName] ) ,
35
- StringTensor ( [ " " ] ) ,
36
- dtypes$dtype: [ Float . tensorFlowDataType] ) )
31
+ return Raw . restoreV2 ( prefix: StringTensor ( path) ,
32
+ tensorNames: StringTensor ( [ tensorName] ) ,
33
+ shapeAndSlices: StringTensor ( [ " " ] ) )
37
34
}
38
35
39
36
/// Increments a per-layer counter for variable names in the checkpoint file.
Original file line number Diff line number Diff line change @@ -19,7 +19,8 @@ import TensorFlow
19
19
/// Computes the Gaussian error linear unit (GELU) nonlinear activation function
20
20
@differentiable
21
21
func gelu< Scalar: TensorFlowFloatingPoint > ( _ x: Tensor < Scalar > ) -> Tensor < Scalar > {
22
- let polynomial = 0.79788456 * ( x + 0.044715 * x * x * x)
22
+ let xCubed = x * x * x
23
+ let polynomial = 0.79788456 * ( x + 0.044715 * xCubed)
23
24
return 0.5 * x * ( 1.0 + tanh( polynomial) )
24
25
}
25
26
Original file line number Diff line number Diff line change @@ -36,12 +36,9 @@ func readTensor<Scalar: TensorFlowScalar>(
36
36
scalarType: Scalar . Type
37
37
) -> Tensor < Scalar > {
38
38
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
39
- return Tensor ( handle: #tfop(
40
- " RestoreV2 " ,
41
- StringTensor ( path) ,
42
- StringTensor ( [ name] ) ,
43
- StringTensor ( [ " " ] ) ,
44
- dtypes$dtype: [ Scalar . tensorFlowDataType] ) )
39
+ return Raw . restoreV2 ( prefix: StringTensor ( path) ,
40
+ tensorNames: StringTensor ( [ name] ) ,
41
+ shapeAndSlices: StringTensor ( [ " " ] ) )
45
42
}
46
43
47
44
protocol InitializableFromPythonCheckpoint {
You can’t perform that action at this time.
0 commit comments