This repository was archived by the owner on Apr 23, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 149
Fast style transfer example #191
Merged
Merged
Changes from 7 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
d40df94
Add fast style transfer example
vvmnnnkv 31d21a8
Fix code formatting
vvmnnnkv aef87e3
Remove Python/Numpy dependency for weights loading
vvmnnnkv 29e9f88
Apply swift-format tool
vvmnnnkv 92c73ba
Merge branch 'master' into fast-style
vvmnnnkv 1739246
Minor formatting fix
vvmnnnkv b397b2f
Fix mac compilation error; wording fix
vvmnnnkv df1665d
Merge branch 'master' into fast-style
BradLarson c3ba5b0
Merge remote-tracking branch 'upstream/master' into fast-style
5f770bd
Update code to work with s4tf v0.6
vvmnnnkv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,4 @@ | |
cifar-10-batches-py/ | ||
cifar-10-batches-bin/ | ||
output/ | ||
.idea |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import Foundation | ||
import TensorFlow | ||
import FastStyleTransfer | ||
|
||
extension TransformerNet: ImportableLayer {} | ||
|
||
/// Updates `obj` with values from command line arguments according to `params` map. | ||
func parseArguments<T>(into obj: inout T, with params: [String: WritableKeyPath<T, String?>]) { | ||
for arg in CommandLine.arguments.dropFirst() { | ||
if !arg.starts(with: "--") { continue } | ||
let parts = arg.split(separator: "=", maxSplits: 2) | ||
let name = String(parts[0][parts[0].index(parts[0].startIndex, offsetBy: 2)...]) | ||
if let path = params[name], parts.count == 2 { | ||
obj[keyPath: path] = String(parts[1]) | ||
} | ||
} | ||
} | ||
|
||
enum FileError: Error { | ||
case fileNotFound | ||
} | ||
|
||
/// Updates `model` with parameters from V2 checkpoint in `path`. | ||
func importWeights(_ model: inout TransformerNet, from path: String) throws { | ||
guard FileManager.default.fileExists(atPath: path + ".data-00000-of-00001") else { | ||
throw FileError.fileNotFound | ||
} | ||
// Names don't match exactly, and axes in filters need to be reversed. | ||
let map = [ | ||
"conv1.conv2d.filter": ("conv1.conv2d.weight", [3, 2, 1, 0]), | ||
"conv2.conv2d.filter": ("conv2.conv2d.weight", [3, 2, 1, 0]), | ||
"conv3.conv2d.filter": ("conv3.conv2d.weight", [3, 2, 1, 0]), | ||
"deconv1.conv2d.filter": ("deconv1.conv2d.weight", [3, 2, 1, 0]), | ||
"deconv2.conv2d.filter": ("deconv2.conv2d.weight", [3, 2, 1, 0]), | ||
"deconv3.conv2d.filter": ("deconv3.conv2d.weight", [3, 2, 1, 0]), | ||
"res1.conv1.conv2d.filter": ("res1.conv1.conv2d.weight", [3, 2, 1, 0]), | ||
"res1.conv2.conv2d.filter": ("res1.conv2.conv2d.weight", [3, 2, 1, 0]), | ||
"res1.in1.scale": ("res1.in1.weight", nil), | ||
"res1.in1.offset": ("res1.in1.bias", nil), | ||
"res1.in2.scale": ("res1.in2.weight", nil), | ||
"res1.in2.offset": ("res1.in2.bias", nil), | ||
"res2.conv1.conv2d.filter": ("res2.conv1.conv2d.weight", [3, 2, 1, 0]), | ||
"res2.conv2.conv2d.filter": ("res2.conv2.conv2d.weight", [3, 2, 1, 0]), | ||
"res2.in1.scale": ("res2.in1.weight", nil), | ||
"res2.in1.offset": ("res2.in1.bias", nil), | ||
"res2.in2.scale": ("res2.in2.weight", nil), | ||
"res2.in2.offset": ("res2.in2.bias", nil), | ||
"res3.conv1.conv2d.filter": ("res3.conv1.conv2d.weight", [3, 2, 1, 0]), | ||
"res3.conv2.conv2d.filter": ("res3.conv2.conv2d.weight", [3, 2, 1, 0]), | ||
"res3.in1.scale": ("res3.in1.weight", nil), | ||
"res3.in1.offset": ("res3.in1.bias", nil), | ||
"res3.in2.scale": ("res3.in2.weight", nil), | ||
"res3.in2.offset": ("res3.in2.bias", nil), | ||
"res4.conv1.conv2d.filter": ("res4.conv1.conv2d.weight", [3, 2, 1, 0]), | ||
"res4.conv2.conv2d.filter": ("res4.conv2.conv2d.weight", [3, 2, 1, 0]), | ||
"res4.in1.scale": ("res4.in1.weight", nil), | ||
"res4.in1.offset": ("res4.in1.bias", nil), | ||
"res4.in2.scale": ("res4.in2.weight", nil), | ||
"res4.in2.offset": ("res4.in2.bias", nil), | ||
"res5.conv1.conv2d.filter": ("res5.conv1.conv2d.weight", [3, 2, 1, 0]), | ||
"res5.conv2.conv2d.filter": ("res5.conv2.conv2d.weight", [3, 2, 1, 0]), | ||
"res5.in1.scale": ("res5.in1.weight", nil), | ||
"res5.in1.offset": ("res5.in1.bias", nil), | ||
"res5.in2.scale": ("res5.in2.weight", nil), | ||
"res5.in2.offset": ("res5.in2.bias", nil), | ||
"in1.scale": ("in1.weight", nil), | ||
"in1.offset": ("in1.bias", nil), | ||
"in2.scale": ("in2.weight", nil), | ||
"in2.offset": ("in2.bias", nil), | ||
"in3.scale": ("in3.weight", nil), | ||
"in3.offset": ("in3.bias", nil), | ||
"in4.scale": ("in4.weight", nil), | ||
"in4.offset": ("in4.bias", nil), | ||
"in5.scale": ("in5.weight", nil), | ||
"in5.offset": ("in5.bias", nil), | ||
] | ||
model.unsafeImport(fromCheckpointPath: path, map: map) | ||
} | ||
|
||
/// Loads from `file` and returns JPEG image as HxWxC tensor of floats in (0..1) range. | ||
func loadJpegAsTensor(from file: String) throws -> Tensor<Float> { | ||
guard FileManager.default.fileExists(atPath: file) else { | ||
throw FileError.fileNotFound | ||
} | ||
let imgData = Raw.readFile(filename: StringTensor(file)) | ||
return Tensor<Float>(Raw.decodeJpeg(contents: imgData, channels: 3, dctMethod: "")) / 255 | ||
} | ||
|
||
/// Clips & converts HxWxC `tensor` of floats to byte range and saves as JPEG. | ||
func saveTensorAsJpeg(_ tensor: Tensor<Float>, to file: String) { | ||
let clipped = Raw.clipByValue(t: tensor, clipValueMin: Tensor(0), clipValueMax: Tensor(255)) | ||
let jpg = Raw.encodeJpeg(image: Tensor<UInt8>(clipped), format: .rgb, xmpMetadata: "") | ||
Raw.writeFile(filename: StringTensor(file), contents: jpg) | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import Foundation | ||
import TensorFlow | ||
import FastStyleTransfer | ||
|
||
func printUsage() { | ||
let exec = URL(string: CommandLine.arguments[0])!.lastPathComponent | ||
print("Usage:") | ||
print("\(exec) --weights=<path> --image=<path> --output=<path>") | ||
print(" --weights: Path to weights in TF checkpoint V2 format") | ||
print(" --image: Path to image in JPEG format") | ||
print(" --output: Path to output image") | ||
} | ||
|
||
/// Startup parameters. | ||
struct Config { | ||
var weights: String? = "FastStyleTransfer/Demo/weights/candy" | ||
var image: String? = nil | ||
var output: String? = "out.jpg" | ||
} | ||
|
||
var config = Config() | ||
parseArguments( | ||
into: &config, | ||
with: [ | ||
"weights": \Config.weights, | ||
"image": \Config.image, | ||
"output": \Config.output | ||
] | ||
) | ||
|
||
guard let image = config.image, let output = config.output else { | ||
print("Error: No input image!") | ||
printUsage() | ||
exit(1) | ||
} | ||
|
||
guard let imageTensor = try? loadJpegAsTensor(from: image) else { | ||
print("Error: Failed to load image \(image). Check file exists and has JPEG format") | ||
printUsage() | ||
exit(1) | ||
} | ||
|
||
// Init the model. | ||
var style = TransformerNet() | ||
do { | ||
try importWeights(&style, from: config.weights!) | ||
} catch { | ||
print("Error: Failed to load weights \(config.weights!). Check path exists and contains TF checkpoint") | ||
printUsage() | ||
exit(1) | ||
} | ||
|
||
// Apply the model to image. | ||
let out = style(imageTensor.expandingShape(at: 0)) | ||
|
||
saveTensorAsJpeg(out.squeezingShape(at: 0), to: output) | ||
print("Written output to \(output)") |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import sys | ||
import torch | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
# Usage: | ||
# python torch-convert.py model.pth model | ||
# (produces tensorflow checkpoint model.*) | ||
|
||
if __name__ == "__main__": | ||
in_file, out_file = sys.argv[1], sys.argv[2] | ||
state_dict = torch.load(in_file) | ||
variables = {} | ||
tf.reset_default_graph() | ||
for label, tensor in state_dict.items(): | ||
variables[label] = tf.get_variable(label, initializer=tensor.numpy()) | ||
|
||
init_op = tf.global_variables_initializer() | ||
saver = tf.train.Saver() | ||
with tf.Session() as sess: | ||
sess.run(init_op) | ||
save_path = saver.save(sess, out_file) |
Binary file not shown.
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import TensorFlow | ||
|
||
/// A 2-D layer applying padding with reflection over a mini-batch. | ||
public struct ReflectionPad2D<Scalar: TensorFlowFloatingPoint>: Layer { | ||
/// The padding values along the spatial dimensions. | ||
@noDerivative public let padding: ((Int, Int), (Int, Int)) | ||
|
||
/// Creates a reflect-padding 2D Layer. | ||
/// | ||
/// - Parameter padding: A tuple of 2 tuples of two integers describing how many elements to | ||
/// be padded at the beginning and end of each padding dimensions. | ||
public init(padding: ((Int, Int), (Int, Int))) { | ||
self.padding = padding | ||
} | ||
|
||
/// Creates a reflect-padding 2D Layer. | ||
/// | ||
/// - Parameter padding: Integer that describes how many elements to be padded | ||
/// at the beginning and end of each padding dimensions. | ||
public init(padding: Int) { | ||
self.padding = ((padding, padding), (padding, padding)) | ||
} | ||
|
||
/// Returns the output obtained from applying the layer to the given input. | ||
/// | ||
/// - Parameter input: The input to the layer. Expected layout is BxHxWxC. | ||
/// - Returns: The output. | ||
@differentiable | ||
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> { | ||
// Padding applied to height and width dimensions only. | ||
return input.paddedWithReflection(forSizes: [ | ||
(0, 0), | ||
padding.0, | ||
padding.1, | ||
(0, 0) | ||
]) | ||
} | ||
} | ||
|
||
/// A layer applying `relu` activation function. | ||
public struct ReLU<Scalar: TensorFlowFloatingPoint>: Layer { | ||
/// Returns the output obtained from applying the layer to the given input. | ||
/// | ||
/// - Parameter input: The input to the layer. | ||
/// - Returns: The output. | ||
@differentiable | ||
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> { | ||
return relu(input) | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import TensorFlow | ||
|
||
/// 2-D layer applying instance normalization over a mini-batch of inputs. | ||
/// | ||
/// Reference: [Instance Normalization](https://arxiv.org/abs/1607.08022) | ||
public struct InstanceNorm2D<Scalar: TensorFlowFloatingPoint>: Layer { | ||
/// Learnable parameter scale for affine transformation. | ||
public var scale: Tensor<Scalar> | ||
/// Learnable parameter offset for affine transformation. | ||
public var offset: Tensor<Scalar> | ||
/// Small value added in denominator for numerical stability. | ||
@noDerivative public var epsilon: Tensor<Scalar> | ||
|
||
/// Creates a instance normalization 2D Layer. | ||
/// | ||
/// - Parameters: | ||
/// - featureCount: Size of the channel axis in the expected input. | ||
/// - epsilon: Small scalar added for numerical stability. | ||
public init(featureCount: Int, epsilon: Tensor<Scalar> = Tensor(1e-5)) { | ||
self.epsilon = epsilon | ||
scale = Tensor<Scalar>(ones: [featureCount]) | ||
offset = Tensor<Scalar>(zeros: [featureCount]) | ||
} | ||
|
||
/// Returns the output obtained from applying the layer to the given input. | ||
/// | ||
/// - Parameter input: The input to the layer. Expected input layout is BxHxWxC. | ||
/// - Returns: The output. | ||
@differentiable | ||
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> { | ||
// Calculate mean & variance along H,W axes. | ||
let mean = input.mean(alongAxes: [1, 2]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
let variance = input.variance(alongAxes: [1, 2]) | ||
let norm = (input - mean) * rsqrt(variance + epsilon) | ||
return norm * scale + offset | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding
InstanceNorm2D
to swift-apis?