Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 25c4452

Browse files
vvmnnnkvBradLarson
authored andcommitted
Fast style transfer example (#191)
* Add fast style transfer example * Fix code formatting * Remove Python/Numpy dependency for weights loading * Apply swift-format tool * Minor formatting fix * Fix mac compilation error; wording fix * Update code to work with s4tf v0.6
1 parent acf4c34 commit 25c4452

24 files changed

+949
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
cifar-10-batches-py/
1010
cifar-10-batches-bin/
1111
output/
12+
.idea
1213
t10k-labels-idx1-ubyte
1314
t10k-images-idx3-ubyte
1415
train-labels-idx1-ubyte

FastStyleTransfer/Demo/ColabDemo.ipynb

Lines changed: 352 additions & 0 deletions
Large diffs are not rendered by default.

FastStyleTransfer/Demo/Helpers.swift

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import Foundation
2+
import TensorFlow
3+
import FastStyleTransfer
4+
5+
extension TransformerNet: ImportableLayer {}
6+
7+
/// Updates `obj` with values from command line arguments according to `params` map.
8+
func parseArguments<T>(into obj: inout T, with params: [String: WritableKeyPath<T, String?>]) {
9+
for arg in CommandLine.arguments.dropFirst() {
10+
if !arg.starts(with: "--") { continue }
11+
let parts = arg.split(separator: "=", maxSplits: 2)
12+
let name = String(parts[0][parts[0].index(parts[0].startIndex, offsetBy: 2)...])
13+
if let path = params[name], parts.count == 2 {
14+
obj[keyPath: path] = String(parts[1])
15+
}
16+
}
17+
}
18+
19+
enum FileError: Error {
20+
case fileNotFound
21+
}
22+
23+
/// Updates `model` with parameters from V2 checkpoint in `path`.
24+
func importWeights(_ model: inout TransformerNet, from path: String) throws {
25+
guard FileManager.default.fileExists(atPath: path + ".data-00000-of-00001") else {
26+
throw FileError.fileNotFound
27+
}
28+
// Names don't match exactly, and axes in filters need to be reversed.
29+
let map = [
30+
"conv1.conv2d.filter": ("conv1.conv2d.weight", [3, 2, 1, 0]),
31+
"conv2.conv2d.filter": ("conv2.conv2d.weight", [3, 2, 1, 0]),
32+
"conv3.conv2d.filter": ("conv3.conv2d.weight", [3, 2, 1, 0]),
33+
"deconv1.conv2d.filter": ("deconv1.conv2d.weight", [3, 2, 1, 0]),
34+
"deconv2.conv2d.filter": ("deconv2.conv2d.weight", [3, 2, 1, 0]),
35+
"deconv3.conv2d.filter": ("deconv3.conv2d.weight", [3, 2, 1, 0]),
36+
"res1.conv1.conv2d.filter": ("res1.conv1.conv2d.weight", [3, 2, 1, 0]),
37+
"res1.conv2.conv2d.filter": ("res1.conv2.conv2d.weight", [3, 2, 1, 0]),
38+
"res1.in1.scale": ("res1.in1.weight", nil),
39+
"res1.in1.offset": ("res1.in1.bias", nil),
40+
"res1.in2.scale": ("res1.in2.weight", nil),
41+
"res1.in2.offset": ("res1.in2.bias", nil),
42+
"res2.conv1.conv2d.filter": ("res2.conv1.conv2d.weight", [3, 2, 1, 0]),
43+
"res2.conv2.conv2d.filter": ("res2.conv2.conv2d.weight", [3, 2, 1, 0]),
44+
"res2.in1.scale": ("res2.in1.weight", nil),
45+
"res2.in1.offset": ("res2.in1.bias", nil),
46+
"res2.in2.scale": ("res2.in2.weight", nil),
47+
"res2.in2.offset": ("res2.in2.bias", nil),
48+
"res3.conv1.conv2d.filter": ("res3.conv1.conv2d.weight", [3, 2, 1, 0]),
49+
"res3.conv2.conv2d.filter": ("res3.conv2.conv2d.weight", [3, 2, 1, 0]),
50+
"res3.in1.scale": ("res3.in1.weight", nil),
51+
"res3.in1.offset": ("res3.in1.bias", nil),
52+
"res3.in2.scale": ("res3.in2.weight", nil),
53+
"res3.in2.offset": ("res3.in2.bias", nil),
54+
"res4.conv1.conv2d.filter": ("res4.conv1.conv2d.weight", [3, 2, 1, 0]),
55+
"res4.conv2.conv2d.filter": ("res4.conv2.conv2d.weight", [3, 2, 1, 0]),
56+
"res4.in1.scale": ("res4.in1.weight", nil),
57+
"res4.in1.offset": ("res4.in1.bias", nil),
58+
"res4.in2.scale": ("res4.in2.weight", nil),
59+
"res4.in2.offset": ("res4.in2.bias", nil),
60+
"res5.conv1.conv2d.filter": ("res5.conv1.conv2d.weight", [3, 2, 1, 0]),
61+
"res5.conv2.conv2d.filter": ("res5.conv2.conv2d.weight", [3, 2, 1, 0]),
62+
"res5.in1.scale": ("res5.in1.weight", nil),
63+
"res5.in1.offset": ("res5.in1.bias", nil),
64+
"res5.in2.scale": ("res5.in2.weight", nil),
65+
"res5.in2.offset": ("res5.in2.bias", nil),
66+
"in1.scale": ("in1.weight", nil),
67+
"in1.offset": ("in1.bias", nil),
68+
"in2.scale": ("in2.weight", nil),
69+
"in2.offset": ("in2.bias", nil),
70+
"in3.scale": ("in3.weight", nil),
71+
"in3.offset": ("in3.bias", nil),
72+
"in4.scale": ("in4.weight", nil),
73+
"in4.offset": ("in4.bias", nil),
74+
"in5.scale": ("in5.weight", nil),
75+
"in5.offset": ("in5.bias", nil),
76+
]
77+
model.unsafeImport(fromCheckpointPath: path, map: map)
78+
}
79+
80+
/// Loads from `file` and returns JPEG image as HxWxC tensor of floats in (0..1) range.
81+
func loadJpegAsTensor(from file: String) throws -> Tensor<Float> {
82+
guard FileManager.default.fileExists(atPath: file) else {
83+
throw FileError.fileNotFound
84+
}
85+
let imgData = _Raw.readFile(filename: StringTensor(file))
86+
return Tensor<Float>(_Raw.decodeJpeg(contents: imgData, channels: 3, dctMethod: "")) / 255
87+
}
88+
89+
/// Clips & converts HxWxC `tensor` of floats to byte range and saves as JPEG.
90+
func saveTensorAsJpeg(_ tensor: Tensor<Float>, to file: String) {
91+
let clipped = _Raw.clipByValue(t: tensor, clipValueMin: Tensor(0), clipValueMax: Tensor(255))
92+
let jpg = _Raw.encodeJpeg(image: Tensor<UInt8>(clipped), format: .rgb, xmpMetadata: "")
93+
_Raw.writeFile(filename: StringTensor(file), contents: jpg)
94+
}
19.7 KB
Loading
44.8 KB
Loading
47.7 KB
Loading

FastStyleTransfer/Demo/main.swift

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import Foundation
2+
import TensorFlow
3+
import FastStyleTransfer
4+
5+
func printUsage() {
6+
let exec = URL(string: CommandLine.arguments[0])!.lastPathComponent
7+
print("Usage:")
8+
print("\(exec) --weights=<path> --image=<path> --output=<path>")
9+
print(" --weights: Path to weights in TF checkpoint V2 format")
10+
print(" --image: Path to image in JPEG format")
11+
print(" --output: Path to output image")
12+
}
13+
14+
/// Startup parameters.
15+
struct Config {
16+
var weights: String? = "FastStyleTransfer/Demo/weights/candy"
17+
var image: String? = nil
18+
var output: String? = "out.jpg"
19+
}
20+
21+
var config = Config()
22+
parseArguments(
23+
into: &config,
24+
with: [
25+
"weights": \Config.weights,
26+
"image": \Config.image,
27+
"output": \Config.output
28+
]
29+
)
30+
31+
guard let image = config.image, let output = config.output else {
32+
print("Error: No input image!")
33+
printUsage()
34+
exit(1)
35+
}
36+
37+
guard let imageTensor = try? loadJpegAsTensor(from: image) else {
38+
print("Error: Failed to load image \(image). Check file exists and has JPEG format")
39+
printUsage()
40+
exit(1)
41+
}
42+
43+
// Init the model.
44+
var style = TransformerNet()
45+
do {
46+
try importWeights(&style, from: config.weights!)
47+
} catch {
48+
print("Error: Failed to load weights \(config.weights!). Check path exists and contains TF checkpoint")
49+
printUsage()
50+
exit(1)
51+
}
52+
53+
// Apply the model to image.
54+
let out = style(imageTensor.expandingShape(at: 0))
55+
56+
saveTensorAsJpeg(out.squeezingShape(at: 0), to: output)
57+
print("Written output to \(output)")
Binary file not shown.
3.29 KB
Binary file not shown.
Binary file not shown.
3.29 KB
Binary file not shown.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import sys
2+
import torch
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
# Usage:
7+
# python torch-convert.py model.pth model
8+
# (produces tensorflow checkpoint model.*)
9+
10+
if __name__ == "__main__":
11+
in_file, out_file = sys.argv[1], sys.argv[2]
12+
state_dict = torch.load(in_file)
13+
variables = {}
14+
tf.reset_default_graph()
15+
for label, tensor in state_dict.items():
16+
variables[label] = tf.get_variable(label, initializer=tensor.numpy())
17+
18+
init_op = tf.global_variables_initializer()
19+
saver = tf.train.Saver()
20+
with tf.Session() as sess:
21+
sess.run(init_op)
22+
save_path = saver.save(sess, out_file)
Binary file not shown.
3.29 KB
Binary file not shown.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import TensorFlow
2+
3+
/// A 2-D layer applying padding with reflection over a mini-batch.
4+
public struct ReflectionPad2D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
5+
/// The padding values along the spatial dimensions.
6+
@noDerivative public let padding: ((Int, Int), (Int, Int))
7+
8+
/// Creates a reflect-padding 2D Layer.
9+
///
10+
/// - Parameter padding: A tuple of 2 tuples of two integers describing how many elements to
11+
/// be padded at the beginning and end of each padding dimensions.
12+
public init(padding: ((Int, Int), (Int, Int))) {
13+
self.padding = padding
14+
}
15+
16+
/// Creates a reflect-padding 2D Layer.
17+
///
18+
/// - Parameter padding: Integer that describes how many elements to be padded
19+
/// at the beginning and end of each padding dimensions.
20+
public init(padding: Int) {
21+
self.padding = ((padding, padding), (padding, padding))
22+
}
23+
24+
/// Returns the output obtained from applying the layer to the given input.
25+
///
26+
/// - Parameter input: The input to the layer. Expected layout is BxHxWxC.
27+
/// - Returns: The output.
28+
@differentiable
29+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
30+
// Padding applied to height and width dimensions only.
31+
return input.padded(forSizes: [
32+
(0, 0),
33+
padding.0,
34+
padding.1,
35+
(0, 0)
36+
], mode: .reflect)
37+
}
38+
}
39+
40+
/// A layer applying `relu` activation function.
41+
public struct ReLU<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
42+
/// Returns the output obtained from applying the layer to the given input.
43+
///
44+
/// - Parameter input: The input to the layer.
45+
/// - Returns: The output.
46+
@differentiable
47+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
48+
return relu(input)
49+
}
50+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import TensorFlow
2+
3+
/// 2-D layer applying instance normalization over a mini-batch of inputs.
4+
///
5+
/// Reference: [Instance Normalization](https://arxiv.org/abs/1607.08022)
6+
public struct InstanceNorm2D<Scalar: TensorFlowFloatingPoint>: Layer {
7+
/// Learnable parameter scale for affine transformation.
8+
public var scale: Tensor<Scalar>
9+
/// Learnable parameter offset for affine transformation.
10+
public var offset: Tensor<Scalar>
11+
/// Small value added in denominator for numerical stability.
12+
@noDerivative public var epsilon: Tensor<Scalar>
13+
14+
/// Creates a instance normalization 2D Layer.
15+
///
16+
/// - Parameters:
17+
/// - featureCount: Size of the channel axis in the expected input.
18+
/// - epsilon: Small scalar added for numerical stability.
19+
public init(featureCount: Int, epsilon: Tensor<Scalar> = Tensor(1e-5)) {
20+
self.epsilon = epsilon
21+
scale = Tensor<Scalar>(ones: [featureCount])
22+
offset = Tensor<Scalar>(zeros: [featureCount])
23+
}
24+
25+
/// Returns the output obtained from applying the layer to the given input.
26+
///
27+
/// - Parameter input: The input to the layer. Expected input layout is BxHxWxC.
28+
/// - Returns: The output.
29+
@differentiable
30+
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
31+
// Calculate mean & variance along H,W axes.
32+
let mean = input.mean(alongAxes: [1, 2])
33+
let variance = input.variance(alongAxes: [1, 2])
34+
let norm = (input - mean) * rsqrt(variance + epsilon)
35+
return norm * scale + offset
36+
}
37+
}

0 commit comments

Comments
 (0)