Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Fast style transfer example #191

Merged
merged 10 commits into from
Dec 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
cifar-10-batches-py/
cifar-10-batches-bin/
output/
.idea
t10k-labels-idx1-ubyte
t10k-images-idx3-ubyte
train-labels-idx1-ubyte
Expand Down
352 changes: 352 additions & 0 deletions FastStyleTransfer/Demo/ColabDemo.ipynb

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions FastStyleTransfer/Demo/Helpers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import Foundation
import TensorFlow
import FastStyleTransfer

extension TransformerNet: ImportableLayer {}

/// Updates `obj` with values from command line arguments according to `params` map.
func parseArguments<T>(into obj: inout T, with params: [String: WritableKeyPath<T, String?>]) {
for arg in CommandLine.arguments.dropFirst() {
if !arg.starts(with: "--") { continue }
let parts = arg.split(separator: "=", maxSplits: 2)
let name = String(parts[0][parts[0].index(parts[0].startIndex, offsetBy: 2)...])
if let path = params[name], parts.count == 2 {
obj[keyPath: path] = String(parts[1])
}
}
}

enum FileError: Error {
case fileNotFound
}

/// Updates `model` with parameters from V2 checkpoint in `path`.
func importWeights(_ model: inout TransformerNet, from path: String) throws {
guard FileManager.default.fileExists(atPath: path + ".data-00000-of-00001") else {
throw FileError.fileNotFound
}
// Names don't match exactly, and axes in filters need to be reversed.
let map = [
"conv1.conv2d.filter": ("conv1.conv2d.weight", [3, 2, 1, 0]),
"conv2.conv2d.filter": ("conv2.conv2d.weight", [3, 2, 1, 0]),
"conv3.conv2d.filter": ("conv3.conv2d.weight", [3, 2, 1, 0]),
"deconv1.conv2d.filter": ("deconv1.conv2d.weight", [3, 2, 1, 0]),
"deconv2.conv2d.filter": ("deconv2.conv2d.weight", [3, 2, 1, 0]),
"deconv3.conv2d.filter": ("deconv3.conv2d.weight", [3, 2, 1, 0]),
"res1.conv1.conv2d.filter": ("res1.conv1.conv2d.weight", [3, 2, 1, 0]),
"res1.conv2.conv2d.filter": ("res1.conv2.conv2d.weight", [3, 2, 1, 0]),
"res1.in1.scale": ("res1.in1.weight", nil),
"res1.in1.offset": ("res1.in1.bias", nil),
"res1.in2.scale": ("res1.in2.weight", nil),
"res1.in2.offset": ("res1.in2.bias", nil),
"res2.conv1.conv2d.filter": ("res2.conv1.conv2d.weight", [3, 2, 1, 0]),
"res2.conv2.conv2d.filter": ("res2.conv2.conv2d.weight", [3, 2, 1, 0]),
"res2.in1.scale": ("res2.in1.weight", nil),
"res2.in1.offset": ("res2.in1.bias", nil),
"res2.in2.scale": ("res2.in2.weight", nil),
"res2.in2.offset": ("res2.in2.bias", nil),
"res3.conv1.conv2d.filter": ("res3.conv1.conv2d.weight", [3, 2, 1, 0]),
"res3.conv2.conv2d.filter": ("res3.conv2.conv2d.weight", [3, 2, 1, 0]),
"res3.in1.scale": ("res3.in1.weight", nil),
"res3.in1.offset": ("res3.in1.bias", nil),
"res3.in2.scale": ("res3.in2.weight", nil),
"res3.in2.offset": ("res3.in2.bias", nil),
"res4.conv1.conv2d.filter": ("res4.conv1.conv2d.weight", [3, 2, 1, 0]),
"res4.conv2.conv2d.filter": ("res4.conv2.conv2d.weight", [3, 2, 1, 0]),
"res4.in1.scale": ("res4.in1.weight", nil),
"res4.in1.offset": ("res4.in1.bias", nil),
"res4.in2.scale": ("res4.in2.weight", nil),
"res4.in2.offset": ("res4.in2.bias", nil),
"res5.conv1.conv2d.filter": ("res5.conv1.conv2d.weight", [3, 2, 1, 0]),
"res5.conv2.conv2d.filter": ("res5.conv2.conv2d.weight", [3, 2, 1, 0]),
"res5.in1.scale": ("res5.in1.weight", nil),
"res5.in1.offset": ("res5.in1.bias", nil),
"res5.in2.scale": ("res5.in2.weight", nil),
"res5.in2.offset": ("res5.in2.bias", nil),
"in1.scale": ("in1.weight", nil),
"in1.offset": ("in1.bias", nil),
"in2.scale": ("in2.weight", nil),
"in2.offset": ("in2.bias", nil),
"in3.scale": ("in3.weight", nil),
"in3.offset": ("in3.bias", nil),
"in4.scale": ("in4.weight", nil),
"in4.offset": ("in4.bias", nil),
"in5.scale": ("in5.weight", nil),
"in5.offset": ("in5.bias", nil),
]
model.unsafeImport(fromCheckpointPath: path, map: map)
}

/// Loads from `file` and returns JPEG image as HxWxC tensor of floats in (0..1) range.
func loadJpegAsTensor(from file: String) throws -> Tensor<Float> {
guard FileManager.default.fileExists(atPath: file) else {
throw FileError.fileNotFound
}
let imgData = _Raw.readFile(filename: StringTensor(file))
return Tensor<Float>(_Raw.decodeJpeg(contents: imgData, channels: 3, dctMethod: "")) / 255
}

/// Clips & converts HxWxC `tensor` of floats to byte range and saves as JPEG.
func saveTensorAsJpeg(_ tensor: Tensor<Float>, to file: String) {
let clipped = _Raw.clipByValue(t: tensor, clipValueMin: Tensor(0), clipValueMax: Tensor(255))
let jpg = _Raw.encodeJpeg(image: Tensor<UInt8>(clipped), format: .rgb, xmpMetadata: "")
_Raw.writeFile(filename: StringTensor(file), contents: jpg)
}
Binary file added FastStyleTransfer/Demo/examples/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added FastStyleTransfer/Demo/examples/cat_candy.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added FastStyleTransfer/Demo/examples/cat_mosaic.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
57 changes: 57 additions & 0 deletions FastStyleTransfer/Demo/main.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import Foundation
import TensorFlow
import FastStyleTransfer

func printUsage() {
let exec = URL(string: CommandLine.arguments[0])!.lastPathComponent
print("Usage:")
print("\(exec) --weights=<path> --image=<path> --output=<path>")
print(" --weights: Path to weights in TF checkpoint V2 format")
print(" --image: Path to image in JPEG format")
print(" --output: Path to output image")
}

/// Startup parameters.
struct Config {
var weights: String? = "FastStyleTransfer/Demo/weights/candy"
var image: String? = nil
var output: String? = "out.jpg"
}

var config = Config()
parseArguments(
into: &config,
with: [
"weights": \Config.weights,
"image": \Config.image,
"output": \Config.output
]
)

guard let image = config.image, let output = config.output else {
print("Error: No input image!")
printUsage()
exit(1)
}

guard let imageTensor = try? loadJpegAsTensor(from: image) else {
print("Error: Failed to load image \(image). Check file exists and has JPEG format")
printUsage()
exit(1)
}

// Init the model.
var style = TransformerNet()
do {
try importWeights(&style, from: config.weights!)
} catch {
print("Error: Failed to load weights \(config.weights!). Check path exists and contains TF checkpoint")
printUsage()
exit(1)
}

// Apply the model to image.
let out = style(imageTensor.expandingShape(at: 0))

saveTensorAsJpeg(out.squeezingShape(at: 0), to: output)
print("Written output to \(output)")
Binary file not shown.
Binary file added FastStyleTransfer/Demo/weights/candy.index
Binary file not shown.
Binary file not shown.
Binary file added FastStyleTransfer/Demo/weights/mosaic.index
Binary file not shown.
22 changes: 22 additions & 0 deletions FastStyleTransfer/Demo/weights/torch-convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
import torch
import numpy as np
import tensorflow as tf

# Usage:
# python torch-convert.py model.pth model
# (produces tensorflow checkpoint model.*)

if __name__ == "__main__":
in_file, out_file = sys.argv[1], sys.argv[2]
state_dict = torch.load(in_file)
variables = {}
tf.reset_default_graph()
for label, tensor in state_dict.items():
variables[label] = tf.get_variable(label, initializer=tensor.numpy())

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
save_path = saver.save(sess, out_file)
Binary file not shown.
Binary file added FastStyleTransfer/Demo/weights/udnie.index
Binary file not shown.
50 changes: 50 additions & 0 deletions FastStyleTransfer/Layers/Helpers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import TensorFlow

/// A 2-D layer applying padding with reflection over a mini-batch.
public struct ReflectionPad2D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
/// The padding values along the spatial dimensions.
@noDerivative public let padding: ((Int, Int), (Int, Int))

/// Creates a reflect-padding 2D Layer.
///
/// - Parameter padding: A tuple of 2 tuples of two integers describing how many elements to
/// be padded at the beginning and end of each padding dimensions.
public init(padding: ((Int, Int), (Int, Int))) {
self.padding = padding
}

/// Creates a reflect-padding 2D Layer.
///
/// - Parameter padding: Integer that describes how many elements to be padded
/// at the beginning and end of each padding dimensions.
public init(padding: Int) {
self.padding = ((padding, padding), (padding, padding))
}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer. Expected layout is BxHxWxC.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
// Padding applied to height and width dimensions only.
return input.padded(forSizes: [
(0, 0),
padding.0,
padding.1,
(0, 0)
], mode: .reflect)
}
}

/// A layer applying `relu` activation function.
public struct ReLU<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer {
/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
return relu(input)
}
}
37 changes: 37 additions & 0 deletions FastStyleTransfer/Layers/Normalization.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import TensorFlow

/// 2-D layer applying instance normalization over a mini-batch of inputs.
///
/// Reference: [Instance Normalization](https://arxiv.org/abs/1607.08022)
public struct InstanceNorm2D<Scalar: TensorFlowFloatingPoint>: Layer {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding InstanceNorm2D to swift-apis?

/// Learnable parameter scale for affine transformation.
public var scale: Tensor<Scalar>
/// Learnable parameter offset for affine transformation.
public var offset: Tensor<Scalar>
/// Small value added in denominator for numerical stability.
@noDerivative public var epsilon: Tensor<Scalar>

/// Creates a instance normalization 2D Layer.
///
/// - Parameters:
/// - featureCount: Size of the channel axis in the expected input.
/// - epsilon: Small scalar added for numerical stability.
public init(featureCount: Int, epsilon: Tensor<Scalar> = Tensor(1e-5)) {
self.epsilon = epsilon
scale = Tensor<Scalar>(ones: [featureCount])
offset = Tensor<Scalar>(zeros: [featureCount])
}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer. Expected input layout is BxHxWxC.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
// Calculate mean & variance along H,W axes.
let mean = input.mean(alongAxes: [1, 2])
Copy link
Contributor

@t-ae t-ae Aug 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moments computes mean and variance simultaneously.
https://github.com/tensorflow/swift-apis/blob/c7595c4b4b0e824cd6d9449e2bed39944fafb472/Sources/TensorFlow/Operators/Math.swift#L2501-L2512
(variance computes mean internally so moments is some more efficient)

let variance = input.variance(alongAxes: [1, 2])
let norm = (input - mean) * rsqrt(variance + epsilon)
return norm * scale + offset
}
}
Loading