Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Fast style transfer example #191

Merged
merged 10 commits into from
Dec 4, 2019
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
cifar-10-batches-py/
cifar-10-batches-bin/
output/
.idea
352 changes: 352 additions & 0 deletions FastStyleTransfer/Demo/ColabDemo.ipynb

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions FastStyleTransfer/Demo/Helpers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import Foundation
import TensorFlow
import FastStyleTransfer

/// Declares `TransformerNet` as conforming to `ImportableLayer` with no additional members.
/// NOTE(review): `importWeights` in this file calls `unsafeImport(fromCheckpointPath:map:)` on the
/// model, which presumably comes from this conformance — confirm against `ImportableLayer`.
extension TransformerNet: ImportableLayer {}

/// Updates `obj` with values from command line arguments according to `params` map.
///
/// Only arguments of the form `--name=value` are considered. The first command line argument
/// (the executable path) is skipped, as are arguments without the `--` prefix, arguments without
/// a value, and names missing from `params`.
///
/// - Parameters:
///   - obj: Value whose properties are overwritten with the parsed argument values.
///   - params: Maps an argument name (without the leading `--`) to the property receiving its value.
func parseArguments<T>(into obj: inout T, with params: [String: WritableKeyPath<T, String?>]) {
    for arg in CommandLine.arguments.dropFirst() where arg.hasPrefix("--") {
        // Split on the first "=" only, so values that themselves contain "=" are kept intact
        // instead of producing a third part and being ignored.
        let parts = arg.split(separator: "=", maxSplits: 1)
        guard parts.count == 2 else { continue }
        // Strip the leading "--" from the name part.
        let name = String(parts[0].dropFirst(2))
        if let path = params[name] {
            obj[keyPath: path] = String(parts[1])
        }
    }
}

/// Errors thrown by the file-loading helpers in this file.
enum FileError: Error {
/// The expected file does not exist at the given path.
case fileNotFound
}

/// Loads the parameters stored in the V2 checkpoint at `path` into `model`.
///
/// - Parameters:
///   - model: Network whose parameters are overwritten.
///   - path: Checkpoint path prefix; the file `path + ".data-00000-of-00001"` must exist.
/// - Throws: `FileError.fileNotFound` when the checkpoint data file is missing.
func importWeights(_ model: inout TransformerNet, from path: String) throws {
    let dataFile = path + ".data-00000-of-00001"
    if !FileManager.default.fileExists(atPath: dataFile) {
        throw FileError.fileNotFound
    }

    // Names don't match exactly, and axes in filters need to be reversed.
    let reversed: [Int]? = [3, 2, 1, 0]
    let map: [String: (String, [Int]?)] = [
        // Down-sampling and up-sampling convolutions.
        "conv1.conv2d.filter": ("conv1.conv2d.weight", reversed),
        "conv2.conv2d.filter": ("conv2.conv2d.weight", reversed),
        "conv3.conv2d.filter": ("conv3.conv2d.weight", reversed),
        "deconv1.conv2d.filter": ("deconv1.conv2d.weight", reversed),
        "deconv2.conv2d.filter": ("deconv2.conv2d.weight", reversed),
        "deconv3.conv2d.filter": ("deconv3.conv2d.weight", reversed),
        // Residual block 1.
        "res1.conv1.conv2d.filter": ("res1.conv1.conv2d.weight", reversed),
        "res1.conv2.conv2d.filter": ("res1.conv2.conv2d.weight", reversed),
        "res1.in1.scale": ("res1.in1.weight", nil),
        "res1.in1.offset": ("res1.in1.bias", nil),
        "res1.in2.scale": ("res1.in2.weight", nil),
        "res1.in2.offset": ("res1.in2.bias", nil),
        // Residual block 2.
        "res2.conv1.conv2d.filter": ("res2.conv1.conv2d.weight", reversed),
        "res2.conv2.conv2d.filter": ("res2.conv2.conv2d.weight", reversed),
        "res2.in1.scale": ("res2.in1.weight", nil),
        "res2.in1.offset": ("res2.in1.bias", nil),
        "res2.in2.scale": ("res2.in2.weight", nil),
        "res2.in2.offset": ("res2.in2.bias", nil),
        // Residual block 3.
        "res3.conv1.conv2d.filter": ("res3.conv1.conv2d.weight", reversed),
        "res3.conv2.conv2d.filter": ("res3.conv2.conv2d.weight", reversed),
        "res3.in1.scale": ("res3.in1.weight", nil),
        "res3.in1.offset": ("res3.in1.bias", nil),
        "res3.in2.scale": ("res3.in2.weight", nil),
        "res3.in2.offset": ("res3.in2.bias", nil),
        // Residual block 4.
        "res4.conv1.conv2d.filter": ("res4.conv1.conv2d.weight", reversed),
        "res4.conv2.conv2d.filter": ("res4.conv2.conv2d.weight", reversed),
        "res4.in1.scale": ("res4.in1.weight", nil),
        "res4.in1.offset": ("res4.in1.bias", nil),
        "res4.in2.scale": ("res4.in2.weight", nil),
        "res4.in2.offset": ("res4.in2.bias", nil),
        // Residual block 5.
        "res5.conv1.conv2d.filter": ("res5.conv1.conv2d.weight", reversed),
        "res5.conv2.conv2d.filter": ("res5.conv2.conv2d.weight", reversed),
        "res5.in1.scale": ("res5.in1.weight", nil),
        "res5.in1.offset": ("res5.in1.bias", nil),
        "res5.in2.scale": ("res5.in2.weight", nil),
        "res5.in2.offset": ("res5.in2.bias", nil),
        // Top-level normalization layers.
        "in1.scale": ("in1.weight", nil),
        "in1.offset": ("in1.bias", nil),
        "in2.scale": ("in2.weight", nil),
        "in2.offset": ("in2.bias", nil),
        "in3.scale": ("in3.weight", nil),
        "in3.offset": ("in3.bias", nil),
        "in4.scale": ("in4.weight", nil),
        "in4.offset": ("in4.bias", nil),
        "in5.scale": ("in5.weight", nil),
        "in5.offset": ("in5.bias", nil),
    ]
    model.unsafeImport(fromCheckpointPath: path, map: map)
}

/// Reads the JPEG image at `file` and decodes it into an HxWxC tensor of floats in (0..1) range.
///
/// - Parameter file: Path of the JPEG file to read.
/// - Returns: The decoded 3-channel image, with every value divided by 255.
/// - Throws: `FileError.fileNotFound` when nothing exists at `file`.
func loadJpegAsTensor(from file: String) throws -> Tensor<Float> {
    if !FileManager.default.fileExists(atPath: file) {
        throw FileError.fileNotFound
    }
    let encoded = Raw.readFile(filename: StringTensor(file))
    let pixels = Raw.decodeJpeg(contents: encoded, channels: 3, dctMethod: "")
    return Tensor<Float>(pixels) / 255
}

/// Writes the HxWxC float `tensor` to `file` as a JPEG image.
///
/// Values are first clipped to the 0...255 range and then converted to bytes.
///
/// - Parameters:
///   - tensor: Image to save.
///   - file: Destination path of the JPEG file.
func saveTensorAsJpeg(_ tensor: Tensor<Float>, to file: String) {
    let lowerBound = Tensor<Float>(0)
    let upperBound = Tensor<Float>(255)
    let inByteRange = Raw.clipByValue(t: tensor, clipValueMin: lowerBound, clipValueMax: upperBound)
    let bytes = Tensor<UInt8>(inByteRange)
    let encoded = Raw.encodeJpeg(image: bytes, format: .rgb, xmpMetadata: "")
    Raw.writeFile(filename: StringTensor(file), contents: encoded)
}
Binary file added FastStyleTransfer/Demo/examples/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added FastStyleTransfer/Demo/examples/cat_candy.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added FastStyleTransfer/Demo/examples/cat_mosaic.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
57 changes: 57 additions & 0 deletions FastStyleTransfer/Demo/main.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import Foundation
import TensorFlow
import FastStyleTransfer

/// Prints command line usage of the demo executable to standard output.
func printUsage() {
    // `URL(fileURLWithPath:)` cannot fail, whereas `URL(string:)` may return nil for paths that
    // are not valid URL strings (e.g. containing spaces), which crashed the force-unwrap.
    let exec = URL(fileURLWithPath: CommandLine.arguments[0]).lastPathComponent
    print("Usage:")
    print("\(exec) --weights=<path> --image=<path> --output=<path>")
    print(" --weights: Path to weights in TF checkpoint V2 format")
    print(" --image: Path to image in JPEG format")
    print(" --output: Path to output image")
}

/// Startup parameters of the demo; each one can be overridden from the command line.
struct Config {
    /// Path to weights in TF checkpoint V2 format.
    var weights: String? = "FastStyleTransfer/Demo/weights/candy"
    /// Path to the input image in JPEG format; there is no default.
    var image: String?
    /// Path to the output image.
    var output: String? = "out.jpg"
}

// Start from the defaults and override them with `--name=value` command line arguments.
var config = Config()
parseArguments(
    into: &config,
    with: [
        "weights": \Config.weights,
        "image": \Config.image,
        "output": \Config.output
    ]
)

// The input image has no default, so it must be supplied on the command line.
guard let image = config.image, let output = config.output else {
    print("Error: No input image!")
    printUsage()
    exit(1)
}

// `weights` has a default value, so this is not expected to fail; binding it once here
// avoids force-unwrapping the optional further down.
guard let weights = config.weights else {
    print("Error: No weights path!")
    printUsage()
    exit(1)
}

guard let imageTensor = try? loadJpegAsTensor(from: image) else {
    print("Error: Failed to load image \(image). Check file exists and has JPEG format")
    printUsage()
    exit(1)
}

// Init the model.
var style = TransformerNet()
do {
    try importWeights(&style, from: weights)
} catch {
    print("Error: Failed to load weights \(weights). Check path exists and contains TF checkpoint")
    printUsage()
    exit(1)
}

// Apply the model to image: add a leading batch axis for the network, drop it again afterwards.
let out = style(imageTensor.expandingShape(at: 0))

saveTensorAsJpeg(out.squeezingShape(at: 0), to: output)
print("Written output to \(output)")
Binary file not shown.
Binary file added FastStyleTransfer/Demo/weights/candy.index
Binary file not shown.
Binary file not shown.
Binary file added FastStyleTransfer/Demo/weights/mosaic.index
Binary file not shown.
22 changes: 22 additions & 0 deletions FastStyleTransfer/Demo/weights/torch-convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
import torch
import numpy as np
import tensorflow as tf

# Usage:
# python torch-convert.py model.pth model
# (produces tensorflow checkpoint model.*)

if __name__ == "__main__":
    # Fail with a usage hint instead of an IndexError when arguments are missing.
    if len(sys.argv) < 3:
        sys.exit("Usage: python torch-convert.py model.pth model")
    in_file, out_file = sys.argv[1], sys.argv[2]

    # Map every tensor to the CPU so that checkpoints saved on a GPU can be loaded on
    # machines without CUDA and converted with `.numpy()`.
    state_dict = torch.load(in_file, map_location="cpu")

    # Create one TensorFlow variable per entry, named after its PyTorch label and
    # initialized with the stored values.
    tf.reset_default_graph()
    for label, tensor in state_dict.items():
        tf.get_variable(label, initializer=tensor.numpy())

    # Initialize the variables and write them out as a TensorFlow checkpoint.
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        saver.save(sess, out_file)
Binary file not shown.
Binary file added FastStyleTransfer/Demo/weights/udnie.index
Binary file not shown.
50 changes: 50 additions & 0 deletions FastStyleTransfer/Layers/Helpers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import TensorFlow

/// A layer that pads the spatial dimensions of a mini-batch of 2-D inputs using reflection.
public struct ReflectionPad2D<Scalar: TensorFlowFloatingPoint>: Layer {
    /// Padding sizes for the two spatial dimensions, each as a (beginning, end) pair.
    @noDerivative public let padding: ((Int, Int), (Int, Int))

    /// Creates a reflect-padding 2D layer with separate sizes per dimension and side.
    ///
    /// - Parameter padding: A tuple of 2 tuples of two integers describing how many elements
    ///   are padded at the beginning and end of each padded dimension.
    public init(padding: ((Int, Int), (Int, Int))) {
        self.padding = padding
    }

    /// Creates a reflect-padding 2D layer with the same size on every side.
    ///
    /// - Parameter padding: Number of elements padded at the beginning and end of each
    ///   padded dimension.
    public init(padding: Int) {
        let symmetric = (padding, padding)
        self.init(padding: (symmetric, symmetric))
    }

    /// Applies the layer to `input` and returns the padded result.
    ///
    /// - Parameter input: The input to the layer. Expected layout is BxHxWxC.
    /// - Returns: The output.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        // The first and last axes get no padding; only height and width are padded.
        return input.paddedWithReflection(forSizes: [(0, 0), padding.0, padding.1, (0, 0)])
    }
}

/// A layer applying `relu` activation function.
public struct ReLU<Scalar: TensorFlowFloatingPoint>: Layer {
    /// Creates a `relu` activation layer.
    ///
    /// Declared explicitly because the implicit initializer of a public struct is internal,
    /// which would make the layer impossible to construct from other modules.
    public init() {}

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer.
    /// - Returns: The output.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        return relu(input)
    }
}
37 changes: 37 additions & 0 deletions FastStyleTransfer/Layers/Normalization.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import TensorFlow

/// 2-D layer applying instance normalization over a mini-batch of inputs.
///
/// Reference: [Instance Normalization](https://arxiv.org/abs/1607.08022)
public struct InstanceNorm2D<Scalar: TensorFlowFloatingPoint>: Layer {
    // NOTE(review): a reviewer suggested adding InstanceNorm2D to swift-apis.

    /// Learnable parameter scale for affine transformation.
    public var scale: Tensor<Scalar>
    /// Learnable parameter offset for affine transformation.
    public var offset: Tensor<Scalar>
    /// Small value added in denominator for numerical stability.
    @noDerivative public var epsilon: Tensor<Scalar>

    /// Creates an instance normalization 2D Layer.
    ///
    /// `scale` starts as ones and `offset` as zeros, both of shape `[featureCount]`.
    ///
    /// - Parameters:
    ///   - featureCount: Size of the channel axis in the expected input.
    ///   - epsilon: Small scalar added for numerical stability.
    public init(featureCount: Int, epsilon: Tensor<Scalar> = Tensor(1e-5)) {
        self.epsilon = epsilon
        scale = Tensor<Scalar>(ones: [featureCount])
        offset = Tensor<Scalar>(zeros: [featureCount])
    }

    /// Returns the output obtained from applying the layer to the given input.
    ///
    /// - Parameter input: The input to the layer. Expected input layout is BxHxWxC.
    /// - Returns: The output.
    @differentiable
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
        // Calculate mean & variance along H,W axes. `moments` computes both simultaneously;
        // `variance` computes the mean internally, so separate calls would compute it twice.
        let moments = input.moments(alongAxes: [1, 2])
        let norm = (input - moments.mean) * rsqrt(moments.variance + epsilon)
        return norm * scale + offset
    }
}
Loading