This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit b8db179

rickwierenga authored and BradLarson committed
Add DCGAN (#261)
* add dcgan
* add batch size
* use swift-format
* format comments
* update to work with master
* Use saveImage instead of matplotlib
* add copyright notice
* remove placeholder labels
* remove unnecessary comments
1 parent 82d1e0c commit b8db179

File tree

3 files changed: +201 -0 lines changed


DCGAN/README.md

Lines changed: 30 additions & 0 deletions
# Deep Convolutional Generative Adversarial Network

arXiv (DCGAN paper): https://arxiv.org/abs/1511.06434 (the original GAN formulation is at https://arxiv.org/abs/1406.2661)

After Epoch 1:
<p align="center">
    <img src="images/epoch-1-output.png" height="270" width="360">
</p>

After Epoch 10:
<p align="center">
    <img src="images/epoch-10-output.png" height="270" width="360">
</p>

## Tutorial

You can read the tutorial on creating this model [here](https://rickwierenga.com/blog/s4tf/s4tf-gan.html) (rickwierenga.com).

## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.
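
For example, if the toolchain were unpacked to `~/swift-tensorflow-toolchain` (an illustrative path, not one this project prescribes), adding it to your path might look like:

```sh
# Illustrative only: substitute the directory where you unpacked the toolchain.
export PATH=~/swift-tensorflow-toolchain/usr/bin:"${PATH}"
swift --version  # should report the Swift for TensorFlow toolchain
```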

To train the model, run:

```sh
swift run DCGAN
```

DCGAN/main.swift

Lines changed: 169 additions & 0 deletions
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import Foundation
import ModelSupport
import TensorFlow

let batchSize = 512
let mnist = MNIST(flattening: false, normalizing: true)

let outputFolder = "./output/"

let zDim = 100

// MARK: - Models

// MARK: Generator

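// The generator maps a zDim-element noise vector to a 28x28x1 image: a dense layer
// produces a 7x7x256 volume, which three transposed convolutions upsample through
// 7x7x128 and 14x14x64 to 28x28x1, with batch normalization and leaky ReLU between
// stages and a tanh on the output.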
struct Generator: Layer {
    var flatten = Flatten<Float>()

    var dense1 = Dense<Float>(inputSize: zDim, outputSize: 7 * 7 * 256)
    var batchNorm1 = BatchNorm<Float>(featureCount: 7 * 7 * 256)
    var transConv2D1 = TransposedConv2D<Float>(
        filterShape: (5, 5, 128, 256),
        strides: (1, 1),
        padding: .same
    )
    var batchNorm2 = BatchNorm<Float>(featureCount: 7 * 7 * 128)
    var transConv2D2 = TransposedConv2D<Float>(
        filterShape: (5, 5, 64, 128),
        strides: (2, 2),
        padding: .same
    )
    var batchNorm3 = BatchNorm<Float>(featureCount: 14 * 14 * 64)
    var transConv2D3 = TransposedConv2D<Float>(
        filterShape: (5, 5, 1, 64),
        strides: (2, 2),
        padding: .same
    )

    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let x1 = leakyRelu(input.sequenced(through: dense1, batchNorm1))
        let x1Reshape = x1.reshaped(to: TensorShape(x1.shape.contiguousSize / (7 * 7 * 256), 7, 7, 256))
        let x2 = leakyRelu(x1Reshape.sequenced(through: transConv2D1, flatten, batchNorm2))
        let x2Reshape = x2.reshaped(to: TensorShape(x2.shape.contiguousSize / (7 * 7 * 128), 7, 7, 128))
        let x3 = leakyRelu(x2Reshape.sequenced(through: transConv2D2, flatten, batchNorm3))
        let x3Reshape = x3.reshaped(to: TensorShape(x3.shape.contiguousSize / (14 * 14 * 64), 14, 14, 64))
        return tanh(transConv2D3(x3Reshape))
    }
}

@differentiable
func generatorLoss(fakeLabels: Tensor<Float>) -> Tensor<Float> {
    sigmoidCrossEntropy(logits: fakeLabels,
        labels: Tensor(ones: fakeLabels.shape))
}

// MARK: Discriminator

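// The discriminator downsamples a 28x28x1 image with two strided convolutions
// (28x28x1 -> 14x14x64 -> 7x7x128), flattens the result (7 * 7 * 128 = 6272 features),
// and maps it to a single real/fake logit.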
struct Discriminator: Layer {
    var conv2D1 = Conv2D<Float>(
        filterShape: (5, 5, 1, 64),
        strides: (2, 2),
        padding: .same
    )
    var dropout = Dropout<Float>(probability: 0.3)
    var conv2D2 = Conv2D<Float>(
        filterShape: (5, 5, 64, 128),
        strides: (2, 2),
        padding: .same
    )
    var flatten = Flatten<Float>()
    var dense = Dense<Float>(inputSize: 6272, outputSize: 1)

    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let x1 = dropout(leakyRelu(conv2D1(input)))
        let x2 = dropout(leakyRelu(conv2D2(x1)))
        return x2.sequenced(through: flatten, dense)
    }
}

@differentiable
func discriminatorLoss(realLabels: Tensor<Float>, fakeLabels: Tensor<Float>) -> Tensor<Float> {
    let realLoss = sigmoidCrossEntropy(logits: realLabels,
        labels: Tensor(ones: realLabels.shape))
    let fakeLoss = sigmoidCrossEntropy(logits: fakeLabels,
        labels: Tensor(zeros: fakeLabels.shape))
    return realLoss + fakeLoss
}

// MARK: - Training

// Create instances of models.
var discriminator = Discriminator()
var generator = Generator()

// Define optimizers.
let optG = Adam(for: generator, learningRate: 0.0001)
let optD = Adam(for: discriminator, learningRate: 0.0001)

// Test noise so we can track progress.
let noise = Tensor<Float>(randomNormal: TensorShape(1, zDim))

print("Begin training...")
let epochs = 20
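// For every batch, the generator is updated first (to make the discriminator score its
// samples as real), then the discriminator is updated on the real batch and a fresh set
// of generated images.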
for epoch in 0 ... epochs {
    Context.local.learningPhase = .training
    let trainingShuffled = mnist.trainingDataset.shuffled(
        sampleCount: mnist.trainingExampleCount, randomSeed: Int64(epoch))
    for batch in trainingShuffled.batched(batchSize) {
        let realImages = batch.data

        // Train generator.
        let noiseG = Tensor<Float>(randomNormal: TensorShape(batchSize, zDim))
        let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
            let fakeImages = generator(noiseG)
            let fakeLabels = discriminator(fakeImages)
            let loss = generatorLoss(fakeLabels: fakeLabels)
            return loss
        }
        optG.update(&generator, along: 𝛁generator)

        // Train discriminator.
        let noiseD = Tensor<Float>(randomNormal: TensorShape(batchSize, zDim))
        let fakeImages = generator(noiseD)

        let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
            let realLabels = discriminator(realImages)
            let fakeLabels = discriminator(fakeImages)
            let loss = discriminatorLoss(realLabels: realLabels, fakeLabels: fakeLabels)
            return loss
        }
        optD.update(&discriminator, along: 𝛁discriminator)
    }

    // Test the networks.
    Context.local.learningPhase = .inference

    // Render images.
    let generatedImage = generator(noise)
    try saveImage(
        generatedImage, size: (28, 28), directory: outputFolder,
        name: "\(epoch).jpg")

    // Print loss.
    let generatorLoss_ = generatorLoss(fakeLabels: discriminator(generatedImage))
    print("epoch: \(epoch) | Generator loss: \(generatorLoss_)")
}

// Generate another image.
let noise1 = Tensor<Float>(randomNormal: TensorShape(1, 100))
let generatedImage = generator(noise1)
try saveImage(
    generatedImage, size: (28, 28), directory: outputFolder,
    name: "final.jpg")

Package.swift

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ let package = Package(
         .executable(name: "MiniGoDemo", targets: ["MiniGoDemo"]),
         .library(name: "MiniGo", targets: ["MiniGo"]),
         .executable(name: "GAN", targets: ["GAN"]),
+        .executable(name: "DCGAN", targets: ["DCGAN"]),
         .executable(name: "FastStyleTransferDemo", targets: ["FastStyleTransferDemo"]),
         .library(name: "FastStyleTransfer", targets: ["FastStyleTransfer"]),
         .executable(name: "Benchmarks", targets: ["Benchmarks"]),
@@ -57,6 +58,7 @@ let package = Package(
         .testTarget(name: "DatasetsTests", dependencies: ["Datasets"]),
         .target(name: "Transformer", path: "Transformer"),
         .target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"),
+        .target(name: "DCGAN", dependencies: ["Datasets", "ModelSupport"], path: "DCGAN"),
         .target(name: "FastStyleTransfer", path: "FastStyleTransfer", exclude: ["Demo"]),
         .target(name: "FastStyleTransferDemo", dependencies: ["FastStyleTransfer"],
             path: "FastStyleTransfer/Demo"),
