Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 31d21a8

Browse files
committed
Fix code formatting
1 parent d40df94 commit 31d21a8

File tree

8 files changed

+204
-94
lines changed

8 files changed

+204
-94
lines changed

FastStyleTransfer/Demo/Helpers.swift

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ import Foundation
22
import TensorFlow
33
import FastStyleTransfer
44

5-
// Make model importable
65
extension TransformerNet: ImportableLayer {}
76

8-
func parseArgs<T>(into obj: inout T, with params: [String: WritableKeyPath<T, String?>]) {
7+
/// Updates `obj` with values from command line arguments according to `params` map.
8+
func parseArguments<T>(into obj: inout T, with params: [String: WritableKeyPath<T, String?>]) {
99
for arg in CommandLine.arguments.dropFirst() {
1010
if !arg.starts(with: "--") { continue }
1111
let parts = arg.split(separator: "=", maxSplits: 2)
@@ -17,15 +17,15 @@ func parseArgs<T>(into obj: inout T, with params: [String: WritableKeyPath<T, St
1717
}
1818

1919
enum FileError: Error {
20-
case file_not_found
20+
case fileNotFound
2121
}
2222

23+
/// Updates `model` with parameters from numpy archive file in `path`.
2324
func importWeights(_ model: inout TransformerNet, from path: String) throws {
2425
guard FileManager.default.fileExists(atPath: path) else {
25-
throw FileError.file_not_found
26+
throw FileError.fileNotFound
2627
}
27-
// Map of model params to loaded params
28-
// Names don't match exactly, and axes in filters need to be reversed
28+
// Names don't match exactly, and axes in filters need to be reversed.
2929
let map = [
3030
"conv1.conv2d.filter": ("conv1.conv2d.weight", [3, 2, 1, 0]),
3131
"conv2.conv2d.filter": ("conv2.conv2d.weight", [3, 2, 1, 0]),
@@ -77,14 +77,16 @@ func importWeights(_ model: inout TransformerNet, from path: String) throws {
7777
model.unsafeImport(fromNumpyArchive: path, map: map)
7878
}
7979

80+
/// Loads from `file` and returns JPEG image as HxWxC tensor of floats in (0..1) range.
8081
func loadJpegAsTensor(from file: String) throws -> Tensor<Float> {
8182
guard FileManager.default.fileExists(atPath: file) else {
82-
throw FileError.file_not_found
83+
throw FileError.fileNotFound
8384
}
8485
let imgData = Raw.readFile(filename: StringTensor(file))
8586
return Tensor<Float>(Raw.decodeJpeg(contents: imgData, channels: 3, dctMethod: "")) / 255
8687
}
8788

89+
/// Clips & converts HxWxC `tensor` of floats to byte range and saves as JPEG.
8890
func saveTensorAsJpeg(_ tensor: Tensor<Float>, to file: String) {
8991
let clipped = Raw.clipByValue(t: tensor, clipValueMin: Tensor(0), clipValueMax: Tensor(255))
9092
let jpg = Raw.encodeJpeg(image: Tensor<UInt8>(clipped), format: .rgb, xmpMetadata: "")

FastStyleTransfer/Demo/main.swift

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,33 @@ func printUsage() {
1111
print(" --output: Path to output image")
1212
}
1313

14+
/// Startup parameters.
1415
struct Config {
1516
var weights: String? = "Demo/weights/candy.npz"
1617
var image: String? = nil
1718
var output: String? = "out.jpg"
1819
}
20+
1921
var config = Config()
20-
parseArgs(into: &config, with: [
22+
parseArguments(into: &config, with: [
2123
"weights": \Config.weights,
2224
"image": \Config.image,
23-
"output": \Config.output])
25+
"output": \Config.output
26+
])
2427

2528
guard let image = config.image, let output = config.output else {
2629
print("Error: No input image!")
2730
printUsage()
2831
exit(1)
2932
}
3033

31-
// load image
3234
guard let imageTensor = try? loadJpegAsTensor(from: image) else {
3335
print("Error: Failed to load image \(image). Check file exists and has JPEG format")
3436
printUsage()
3537
exit(1)
3638
}
3739

38-
// init model
40+
// Init the model.
3941
var style = TransformerNet()
4042
do {
4143
try importWeights(&style, from: config.weights!)
@@ -45,6 +47,8 @@ do {
4547
exit(1)
4648
}
4749

50+
// Apply the model to image.
4851
let out = style(imageTensor.expandingShape(at: 0))
52+
4953
saveTensorAsJpeg(out.squeezingShape(at: 0), to: output)
50-
print("Written output to \(output)")
54+
print("Written output to \(output)")
Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,33 @@
11
import TensorFlow
22

3-
/// Layer for padding with reflection over mini-batch of images
4-
/// Expected input layout is BxHxWxC
5-
public struct ReflectionPad2d<Scalar: TensorFlowFloatingPoint>: Layer {
3+
/// A 2-D layer applying padding with reflection over a mini-batch.
4+
public struct ReflectionPad2D<Scalar: TensorFlowFloatingPoint>: Layer {
5+
/// The padding values along the spatial dimensions.
66
@noDerivative public let padding: ((Int, Int), (Int, Int))
77

8+
/// Creates a reflect-padding 2D Layer.
9+
///
10+
/// - Parameter padding: A tuple of 2 tuples of two integers describing how many elements to
11+
/// be padded at the beginning and end of each padding dimensions.
812
public init(padding: ((Int, Int), (Int, Int))) {
913
self.padding = padding
1014
}
1115

16+
/// Creates a reflect-padding 2D Layer.
17+
///
18+
/// - Parameter padding: Integer that describes how many elements to be padded
19+
/// at the beginning and end of each padding dimensions.
1220
public init(padding: Int) {
1321
self.padding = ((padding, padding), (padding, padding))
1422
}
1523

24+
/// Returns the output obtained from applying the layer to the given input.
25+
///
26+
/// - Parameter input: The input to the layer. Expected layout is BxHxWxC.
27+
/// - Returns: The output.
1628
@differentiable
1729
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
30+
// Padding applied to height and width dimensions only.
1831
return input.paddedWithReflection(forSizes: [
1932
(0, 0),
2033
padding.0,
@@ -24,11 +37,14 @@ public struct ReflectionPad2d<Scalar: TensorFlowFloatingPoint>: Layer {
2437
}
2538
}
2639

27-
28-
/// Layer applying relu activation function
40+
/// A layer applying `relu` activation function.
2941
public struct ReLU<Scalar: TensorFlowFloatingPoint>: Layer {
42+
/// Returns the output obtained from applying the layer to the given input.
43+
///
44+
/// - Parameter input: The input to the layer.
45+
/// - Returns: The output.
3046
@differentiable
3147
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
3248
return relu(input)
3349
}
34-
}
50+
}

FastStyleTransfer/Layers/Normalization.swift

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,34 @@
11
import TensorFlow
22

3-
/// Layer that applies instance normalization over a mini-batch of images
4-
/// Expected input layout is BxHxWxC
3+
/// 2-D layer applying instance normalization over a mini-batch of inputs.
4+
///
55
/// Reference: [Instance Normalization](https://arxiv.org/abs/1607.08022)
6-
public struct InstanceNorm2d<Scalar: TensorFlowFloatingPoint>: Layer {
6+
public struct InstanceNorm2D<Scalar: TensorFlowFloatingPoint>: Layer {
7+
/// Learnable parameter scale for affine transformation.
78
public var scale: Tensor<Scalar>
9+
/// Learnable parameter offset for affine transformation.
810
public var offset: Tensor<Scalar>
11+
/// Small value added in denominator for numerical stability.
912
@noDerivative public var epsilon: Tensor<Scalar>
1013

14+
/// Creates a instance normalization 2D Layer.
15+
///
16+
/// - Parameters:
17+
/// - featureCount: Size of the channel axis in the expected input.
18+
/// - epsilon: Small scalar added for numerical stability.
1119
public init(featureCount: Int, epsilon: Tensor<Scalar> = Tensor(1e-5)) {
1220
self.epsilon = epsilon
1321
scale = Tensor<Scalar>(ones: [featureCount])
1422
offset = Tensor<Scalar>(zeros: [featureCount])
1523
}
1624

25+
/// Returns the output obtained from applying the layer to the given input.
26+
///
27+
/// - Parameter input: The input to the layer. Expected input layout is BxHxWxC.
28+
/// - Returns: The output.
1729
@differentiable
1830
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
31+
// Calculate mean & variance along H,W axes.
1932
let mean = input.mean(alongAxes: [1, 2])
2033
let variance = input.variance(alongAxes: [1, 2])
2134
let norm = (input - mean) * rsqrt(variance + epsilon)

0 commit comments

Comments
 (0)