Skip to content

Commit 4db0b9f

Browse files
2winsDEKHTIARJonathan
authored andcommitted
Create tutorial_tfslim.py (#560)
* Create tutorial_tfslim.py fixes #552 * Update tutorial_tfslim.py
1 parent 74a59ce commit 4db0b9f

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

example/tutorial_tfslim.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import time
5+
import tensorflow as tf
6+
import tensorlayer as tl
7+
import tensorflow.contrib.slim as slim
8+
from tensorlayer.layers import *
9+
10+
X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 784))
11+
12+
sess = tf.InteractiveSession()
13+
14+
batch_size = 128
15+
x = tf.placeholder(tf.float32, shape=[None, 784])
16+
y_ = tf.placeholder(tf.int64, shape=[None])
17+
is_training = tf.placeholder(tf.bool)
18+
19+
20+
def slim_block(x):
21+
with tf.variable_scope('tf_slim'):
22+
x = slim.dropout(x, 0.8, is_training=is_training)
23+
x = slim.fully_connected(x, 800, activation_fn=tf.nn.relu)
24+
x = slim.dropout(x, 0.5, is_training=is_training)
25+
x = slim.fully_connected(x, 800, activation_fn=tf.nn.relu)
26+
x = slim.dropout(x, 0.5, is_training=is_training)
27+
logits = slim.fully_connected(x, 10, activation_fn=tf.identity)
28+
return logits, {}
29+
30+
31+
network = InputLayer(x, name='input')
32+
network = SlimNetsLayer(network, slim_layer=slim_block, name='tf_slim')
33+
34+
y = network.outputs
35+
network.print_params(False)
36+
network.print_layers()
37+
38+
cost = tl.cost.cross_entropy(y, y_, 'cost')
39+
correct_prediction = tf.equal(tf.argmax(y, 1), y_)
40+
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
41+
42+
n_epoch = 200
43+
learning_rate = 0.0001
44+
45+
train_params = network.all_params
46+
train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost, var_list=train_params)
47+
48+
tl.layers.initialize_global_variables(sess)
49+
50+
for epoch in range(n_epoch):
51+
start_time = time.time()
52+
## Training
53+
for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
54+
_, _ = sess.run([cost, train_op], feed_dict={x: X_train_a, y_: y_train_a, is_training: True})
55+
56+
print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
57+
## Evaluation
58+
train_loss, train_acc, n_batch = 0, 0, 0
59+
for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=False):
60+
err, ac = sess.run([cost, acc], feed_dict={x: X_train_a, y_: y_train_a, is_training: False})
61+
train_loss += err
62+
train_acc += ac
63+
n_batch += 1
64+
print(" train loss: %f" % (train_loss / n_batch))
65+
print(" train acc: %f" % (train_acc / n_batch))
66+
val_loss, val_acc, n_batch = 0, 0, 0
67+
for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=False):
68+
err, ac = sess.run([cost, acc], feed_dict={x: X_val_a, y_: y_val_a, is_training: False})
69+
val_loss += err
70+
val_acc += ac
71+
n_batch += 1
72+
print(" val loss: %f" % (val_loss / n_batch))
73+
print(" val acc: %f" % (val_acc / n_batch))

0 commit comments

Comments
 (0)