8
8
import torch .nn .functional as F
9
9
import torch .optim as optim
10
10
import torch .utils .data as data
11
- import torchvision .transforms as transforms
12
11
import torchvision .datasets as datasets
13
-
14
- from torch .utils .tensorboard import SummaryWriter
15
-
12
+ import torchvision .transforms as transforms
16
13
from vgg16 import vgg16
17
14
18
15
PARSER = argparse .ArgumentParser (
64
61
65
62
timestamp = datetime .timestamp (now )
66
63
67
- writer = SummaryWriter (args .tensorboard + "/test_" + str (timestamp ))
68
64
classes = (
69
65
"plane" ,
70
66
"car" ,
82
78
def main ():
83
79
global state
84
80
global classes
85
- global writer
86
81
if not os .path .isdir (args .ckpt_dir ):
87
82
os .makedirs (args .ckpt_dir )
88
83
@@ -131,9 +126,6 @@ def main():
131
126
data = iter (training_dataloader )
132
127
images , _ = next (data )
133
128
134
- writer .add_graph (model , images .cuda ())
135
- writer .close ()
136
-
137
129
crit = nn .CrossEntropyLoss ()
138
130
opt = optim .SGD (
139
131
model .parameters (),
@@ -156,8 +148,6 @@ def main():
156
148
157
149
for epoch in range (args .start_from , args .epochs ):
158
150
adjust_lr (opt , epoch )
159
- writer .add_scalar ("Learning Rate" , state ["lr" ], epoch )
160
- writer .close ()
161
151
print ("Epoch: [%5d / %5d] LR: %f" % (epoch + 1 , args .epochs , state ["lr" ]))
162
152
163
153
train (model , training_dataloader , crit , opt , epoch )
@@ -179,7 +169,6 @@ def main():
179
169
180
170
181
171
def train (model , dataloader , crit , opt , epoch ):
182
- global writer
183
172
model .train ()
184
173
running_loss = 0.0
185
174
for batch , (data , labels ) in enumerate (dataloader ):
@@ -192,10 +181,6 @@ def train(model, dataloader, crit, opt, epoch):
192
181
193
182
running_loss += loss .item ()
194
183
if batch % 50 == 49 :
195
- writer .add_scalar (
196
- "Training Loss" , running_loss / 100 , epoch * len (dataloader ) + batch
197
- )
198
- writer .close ()
199
184
print (
200
185
"Batch: [%5d | %5d] loss: %.3f"
201
186
% (batch + 1 , len (dataloader ), running_loss / 100 )
@@ -204,7 +189,6 @@ def train(model, dataloader, crit, opt, epoch):
204
189
205
190
206
191
def test (model , dataloader , crit , epoch ):
207
- global writer
208
192
global classes
209
193
total = 0
210
194
correct = 0
@@ -223,12 +207,6 @@ def test(model, dataloader, crit, epoch):
223
207
total += labels .size (0 )
224
208
correct += (preds == labels ).sum ().item ()
225
209
226
- writer .add_scalar ("Testing Loss" , loss / total , epoch )
227
- writer .close ()
228
-
229
- writer .add_scalar ("Testing Accuracy" , correct / total * 100 , epoch )
230
- writer .close ()
231
-
232
210
test_probs = torch .cat ([torch .stack (batch ) for batch in class_probs ])
233
211
test_preds = torch .cat (class_preds )
234
212
for i in range (len (classes )):
@@ -263,14 +241,6 @@ def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0)
263
241
tensorboard_preds = test_preds == class_index
264
242
tensorboard_probs = test_probs [:, class_index ]
265
243
266
- writer .add_pr_curve (
267
- classes [class_index ],
268
- tensorboard_preds ,
269
- tensorboard_probs ,
270
- global_step = global_step ,
271
- )
272
- writer .close ()
273
-
274
244
275
245
if __name__ == "__main__" :
276
246
main ()
0 commit comments