Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit aef87e3

Browse files
committed
Remove Python/Numpy dependency for weights loading
1 parent 31d21a8 commit aef87e3

File tree

11 files changed

+38
-32
lines changed

11 files changed

+38
-32
lines changed

FastStyleTransfer/Demo/ColabDemo.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,10 @@
277277
"let styles = [\"candy\", \"mosaic\", \"udnie\"] \n",
278278
"for s in styles {\n",
279279
" // Download pre-trained weights\n",
280-
" downloadFile(from: \"https://github.com/vvmnnnkv/swift-models/raw/fast-style/FastStyleTransfer/Demo/weights/\\(s).npz\", to: \"\\(s).npz\")\n",
280+
" downloadFile(from: \"https://github.com/vvmnnnkv/swift-models/raw/fast-style/FastStyleTransfer/Demo/weights/\\(s).data-00000-of-00001\", to: \"\\(s).data-00000-of-00001\")\n",
281+
" downloadFile(from: \"https://github.com/vvmnnnkv/swift-models/raw/fast-style/FastStyleTransfer/Demo/weights/\\(s).index\", to: \"\\(s).index\")\n",
281282
" // Load weights into model\n",
282-
" style.unsafeImport(fromNumpyArchive: \"\\(s).npz\", map: map)\n",
283+
" style.unsafeImport(fromCheckpointPath: \"\\(s)\", map: map)\n",
283284
" // Apply model to image\n",
284285
" let out = style(image.expandingShape(at: 0)) / 255\n",
285286
" show_img(out.squeezingShape(at: 0))\n",

FastStyleTransfer/Demo/Helpers.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ enum FileError: Error {
2020
case fileNotFound
2121
}
2222

23-
/// Updates `model` with parameters from numpy archive file in `path`.
23+
/// Updates `model` with parameters from V2 checkpoint in `path`.
2424
func importWeights(_ model: inout TransformerNet, from path: String) throws {
25-
guard FileManager.default.fileExists(atPath: path) else {
25+
guard FileManager.default.fileExists(atPath: path + ".data-00000-of-00001") else {
2626
throw FileError.fileNotFound
2727
}
2828
// Names don't match exactly, and axes in filters need to be reversed.
@@ -74,7 +74,7 @@ func importWeights(_ model: inout TransformerNet, from path: String) throws {
7474
"in5.scale": ("in5.weight", nil),
7575
"in5.offset": ("in5.bias", nil),
7676
]
77-
model.unsafeImport(fromNumpyArchive: path, map: map)
77+
model.unsafeImport(fromCheckpointPath: path, map: map)
7878
}
7979

8080
/// Loads from `file` and returns JPEG image as HxWxC tensor of floats in (0..1) range.
3.29 KB
Binary file not shown.
3.29 KB
Binary file not shown.
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
import sys
22
import torch
33
import numpy as np
4+
import tensorflow as tf
45

56
# Usage:
67
# python torch-convert.py model.pth model
7-
# (produces model.npz)
8+
# (produces tensorflow checkpoint model.*)
89

910
if __name__ == "__main__":
10-
in_file, out_file= sys.argv[1], sys.argv[2]
11-
state_dict = torch.load(in_file)
12-
npz = {}
13-
for label, tensor in state_dict.items():
14-
npz[label] = tensor.numpy()
15-
np.savez(out_file, **npz)
11+
in_file, out_file = sys.argv[1], sys.argv[2]
12+
state_dict = torch.load(in_file)
13+
variables = {}
14+
tf.reset_default_graph()
15+
for label, tensor in state_dict.items():
16+
variables[label] = tf.get_variable(label, initializer=tensor.numpy())
17+
18+
init_op = tf.global_variables_initializer()
19+
saver = tf.train.Saver()
20+
with tf.Session() as sess:
21+
sess.run(init_op)
22+
save_path = saver.save(sess, out_file)
3.29 KB
Binary file not shown.

FastStyleTransfer/README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@ The model should be trainable, but so far it's only tested for inference with pr
66
## Example
77
Run demo application to apply styles to jpeg images:
88
```
9-
swift run FastStyleTranserDemo --weights=FastStyleTranser/Demo/weights/candy.npz --input=FastStyleTranser/Demo/examples/cat.jpg --output=candy_cat.jpg
10-
swift run FastStyleTranserDemo --weights=FastStyleTranser/Demo/weights/mosaic.npz --input=FastStyleTranser/Demo/examples/cat.jpg --output=mosaic_cat.jpg
9+
swift run FastStyleTranserDemo --weights=FastStyleTranser/Demo/weights/candy --input=FastStyleTranser/Demo/examples/cat.jpg --output=candy_cat.jpg
10+
swift run FastStyleTranserDemo --weights=FastStyleTranser/Demo/weights/mosaic --input=FastStyleTranser/Demo/examples/cat.jpg --output=mosaic_cat.jpg
1111
```
1212

1313
<img src="Demo/examples/cat.jpg" height="240" width="240" align="left">
1414
<img src="Demo/examples/cat_candy.jpg" height="240" width="240" align="left">
1515
<img src="Demo/examples/cat_mosaic.jpg" height="240" width="240">
1616

17-
## Requirements
18-
Requires Python and NumPy to load weights.
19-
2017
## Jupyter Notebook
2118
Run [demo notebook](Demo/ColabDemo.ipynb) in [Colab](https://colab.research.google.com/github/vvmnnnkv/swift-models/blob/fast-style/FastStyleTransfer/Demo/ColabDemo.ipynb)!

FastStyleTransfer/Utility/ImportableLayer.swift

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import TensorFlow
2-
import Python
3-
let np = Python.import("numpy")
42

53
public protocol ImportableLayer: KeyPathIterable {}
64

@@ -46,14 +44,16 @@ public extension ImportableLayer {
4644
}
4745

4846
/// Updates model parameters with values from `parameters`, according to `ImportMap`.
49-
/// TODO add shapes check
50-
mutating func unsafeImport<T>(parameters: [String: Tensor<T>], map: ImportMap) {
51-
for (label, keyPath) in getRecursiveNamedKeyPaths(ofType: Tensor<T>.self) {
52-
// let shape = self[keyPath: keyPath].shape
47+
mutating func unsafeImport(parameters: [String: Tensor<Float>], map: ImportMap) {
48+
for (label, keyPath) in getRecursiveNamedKeyPaths(ofType: Tensor<Float>.self) {
49+
let shape = self[keyPath: keyPath].shape
5350
if let mapping = map[label], var weights = parameters[mapping.0] {
5451
if let permutes = mapping.1 {
5552
weights = weights.transposed(withPermutations: permutes)
5653
}
54+
if weights.shape != shape {
55+
fatalError("Shapes do not match for \(label): \(shape) vs. \(weights.shape)")
56+
}
5757
self[keyPath: keyPath] = weights
5858
// print("imported \(mapping.0) \(shape) -> \(label) \(weights.shape)")
5959
} else if let weights = parameters[label] {
@@ -65,15 +65,16 @@ public extension ImportableLayer {
6565
}
6666

6767
public extension ImportableLayer {
68-
/// Updates model parameters with values from numpy archive, according to `ImportMap`.
69-
mutating func unsafeImport(fromNumpyArchive file: String, map: ImportMap) {
70-
let data = np.load(file)
71-
var parameters = [String: Tensor<Float>]()
72-
for label in data.files {
73-
if let label = String(label) {
74-
parameters[label] = Tensor<Float>(numpy: data[label])
75-
}
76-
}
68+
/// Updates model parameters with values from V2 checkpoint, according to `ImportMap`.
69+
mutating func unsafeImport(fromCheckpointPath path: String, map: ImportMap) {
70+
let tensorNames = map.values.map { $0.0 }
71+
let tensorValues = Raw.restoreV2(
72+
prefix: StringTensor(path),
73+
tensorNames: StringTensor(tensorNames),
74+
shapeAndSlices: StringTensor(Array(repeating: "", count: tensorNames.count)),
75+
dtypes: Array(repeating: Float.tensorFlowDataType, count: tensorNames.count)
76+
).map { $0 as! Tensor<Float> }
77+
let parameters = Dictionary(uniqueKeysWithValues: zip(tensorNames, tensorValues))
7778
unsafeImport(parameters: parameters, map: map)
7879
}
7980
}

0 commit comments

Comments
 (0)