Skip to content

Commit 3704557

Browse files
committed
Small copy update
1 parent 9c8c54a commit 3704557

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

examples/nets/unet.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Demonstrate Unet training on label-maker prepared data"""
2+
3+
from __future__ import print_function
4+
import numpy as np
5+
import keras
6+
from keras.models import Model
7+
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose
8+
from keras.optimizers import Adam
9+
from keras.callbacks import ModelCheckpoint
10+
from keras.preprocessing.image import ImageDataGenerator
11+
from keras import backend as K
12+
13+
batch_size = 16
14+
num_classes = 2
15+
epochs = 100
16+
17+
smooth = 1.
18+
19+
# input image dimensions
20+
img_rows, img_cols = 256, 256
21+
22+
# the data, shuffled and split between train and test sets
23+
npz = np.load('data.npz')
24+
x_train = npz['x_train']
25+
y_train = npz['y_train']
26+
x_test = npz['x_test']
27+
y_test = npz['y_test']
28+
29+
if K.image_data_format() == 'channels_first':
30+
x_train = x_train.reshape(x_train.shape[0], 3, img_rows, img_cols)
31+
x_test = x_test.reshape(x_test.shape[0], 3, img_rows, img_cols)
32+
input_shape = (3, img_rows, img_cols)
33+
else:
34+
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 3)
35+
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 3)
36+
input_shape = (img_rows, img_cols, 3)
37+
38+
def dice_coef(y_true, y_pred):
39+
y_true_f = K.flatten(y_true)
40+
y_pred_f = K.flatten(y_pred)
41+
intersection = K.sum(y_true_f * y_pred_f)
42+
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
43+
44+
45+
def dice_coef_loss(y_true, y_pred):
46+
return -dice_coef(y_true, y_pred)
47+
48+
49+
def get_unet():
50+
inputs = Input(input_shape)
51+
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
52+
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
53+
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
54+
55+
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
56+
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
57+
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
58+
59+
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
60+
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
61+
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
62+
63+
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
64+
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
65+
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
66+
67+
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
68+
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
69+
70+
up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
71+
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
72+
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
73+
74+
up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
75+
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
76+
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
77+
78+
up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
79+
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
80+
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
81+
82+
up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
83+
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
84+
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
85+
86+
conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
87+
88+
model = Model(inputs=[inputs], outputs=[conv10])
89+
90+
model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])
91+
92+
return model
93+
94+
95+
x_train = x_train.astype('float32')
96+
x_test = x_test.astype('float32')
97+
98+
print('x_train shape:', x_train.shape)
99+
print(x_train.shape[0], 'train samples')
100+
print(x_test.shape[0], 'test samples')
101+
102+
x_train /= 255
103+
x_test /= 255
104+
105+
# normalize the images
106+
img_mean = np.mean(x_train, axis=(0, 1, 2))
107+
img_std = np.std(x_train, axis=(0, 1, 2))
108+
x_train -= img_mean
109+
x_train /= img_std
110+
111+
x_test -= img_mean
112+
x_test /= img_std
113+
114+
115+
datagen = ImageDataGenerator(
116+
rotation_range=180, # randomly rotate images in the range (degrees, 0 to 180)
117+
horizontal_flip=True, # randomly flip images
118+
vertical_flip=False
119+
)
120+
121+
model = get_unet()
122+
123+
# Fit the model on the batches generated by datagen.flow().
124+
model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
125+
steps_per_epoch=int(x_train.shape[0] / batch_size),
126+
epochs=epochs,
127+
validation_data=(x_test, y_test),
128+
verbose=1,
129+
workers=4)
130+
131+
score = model.evaluate(x_test, y_test, verbose=0)
132+
print('Test loss:', score[0])
133+
print('Test accuracy:', score[1])
134+
model.save('model.h5')

examples/walkthrough-tensorflow-object-detection.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Google TensorFlow Object Detection API is an open source framework built on top
88

99
First install Label Maker (`pip install label-maker`), [tippecanoe](https://github.com/mapbox/tippecanoe) and Pandas (`pip install pandas`).
1010

11-
**Note:** *If you want to learn how TensorFlow object detection works and how to setup the workflow, you should follow following instructions step by step. If you want to skip the steps and automate the workflow, use our created **Dockerfile** [follow this instruction](https://github.com/Rub21/tensorflow-building-detection) instead*
11+
**Note:** *If you want to learn how TensorFlow object detection works and how to setup the workflow, you should follow these instructions step by step. If you want to skip the steps and automate the workflow, you can use our docker image and [follow these instructions](https://github.com/Rub21/tensorflow-building-detection) instead.*
1212

1313
## Create the training dataset
1414

0 commit comments

Comments
 (0)