Commit cf3a688

DLFW changes (#2552)
1 parent 593ff44 commit cf3a688

2 files changed: 2 additions & 32 deletions


examples/int8/training/vgg16/main.py

Lines changed: 1 addition & 31 deletions
```diff
@@ -8,11 +8,8 @@
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
-import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-
-from torch.utils.tensorboard import SummaryWriter
-
+import torchvision.transforms as transforms
 from vgg16 import vgg16
 
 PARSER = argparse.ArgumentParser(
@@ -64,7 +61,6 @@
 
 timestamp = datetime.timestamp(now)
 
-writer = SummaryWriter(args.tensorboard + "/test_" + str(timestamp))
 classes = (
     "plane",
     "car",
@@ -82,7 +78,6 @@
 def main():
     global state
     global classes
-    global writer
     if not os.path.isdir(args.ckpt_dir):
         os.makedirs(args.ckpt_dir)
 
@@ -131,9 +126,6 @@ def main():
     data = iter(training_dataloader)
     images, _ = next(data)
 
-    writer.add_graph(model, images.cuda())
-    writer.close()
-
     crit = nn.CrossEntropyLoss()
     opt = optim.SGD(
         model.parameters(),
@@ -156,8 +148,6 @@ def main():
 
     for epoch in range(args.start_from, args.epochs):
         adjust_lr(opt, epoch)
-        writer.add_scalar("Learning Rate", state["lr"], epoch)
-        writer.close()
         print("Epoch: [%5d / %5d] LR: %f" % (epoch + 1, args.epochs, state["lr"]))
 
         train(model, training_dataloader, crit, opt, epoch)
@@ -179,7 +169,6 @@ def main():
 
 
 def train(model, dataloader, crit, opt, epoch):
-    global writer
     model.train()
     running_loss = 0.0
     for batch, (data, labels) in enumerate(dataloader):
@@ -192,10 +181,6 @@ def train(model, dataloader, crit, opt, epoch):
 
         running_loss += loss.item()
         if batch % 50 == 49:
-            writer.add_scalar(
-                "Training Loss", running_loss / 100, epoch * len(dataloader) + batch
-            )
-            writer.close()
             print(
                 "Batch: [%5d | %5d] loss: %.3f"
                 % (batch + 1, len(dataloader), running_loss / 100)
@@ -204,7 +189,6 @@ def train(model, dataloader, crit, opt, epoch):
 
 
 def test(model, dataloader, crit, epoch):
-    global writer
     global classes
     total = 0
     correct = 0
@@ -223,12 +207,6 @@ def test(model, dataloader, crit, epoch):
             total += labels.size(0)
             correct += (preds == labels).sum().item()
 
-    writer.add_scalar("Testing Loss", loss / total, epoch)
-    writer.close()
-
-    writer.add_scalar("Testing Accuracy", correct / total * 100, epoch)
-    writer.close()
-
     test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
     test_preds = torch.cat(class_preds)
     for i in range(len(classes)):
@@ -263,14 +241,6 @@ def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0)
     tensorboard_preds = test_preds == class_index
     tensorboard_probs = test_probs[:, class_index]
 
-    writer.add_pr_curve(
-        classes[class_index],
-        tensorboard_preds,
-        tensorboard_probs,
-        global_step=global_step,
-    )
-    writer.close()
-
 
 if __name__ == "__main__":
     main()
```
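This removes the example's TensorBoard dependency entirely. For reference, a minimal self-contained sketch of the logging pattern that was deleted, using only the standard `torch.utils.tensorboard` API (the log directory and scalar values here are illustrative, not taken from the original script):

```python
from torch.utils.tensorboard import SummaryWriter

# Illustrative log directory; the old script built it from args.tensorboard.
writer = SummaryWriter("runs/vgg16_int8_example")

for epoch in range(3):
    running_loss = 1.0 / (epoch + 1)  # stand-in for the loop's real running loss
    writer.add_scalar("Training Loss", running_loss, epoch)

# Close once at the end to flush events; the removed code called close()
# after every add_scalar, which is unnecessary.
writer.close()
```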

notebooks/EfficientNet-example.ipynb

Lines changed: 1 addition & 1 deletion
```diff
@@ -526,7 +526,7 @@
     "# The compiled module will have precision as specified by \"op_precision\".\n",
     "# Here, it will have FP32 precision.\n",
     "trt_model_fp32 = torch_tensorrt.compile(model, inputs = [torch_tensorrt.Input((128, 3, 224, 224), dtype=torch.float32)],\n",
-    "    enabled_precisions = torch.float32, # Run with FP32\n",
+    "    enabled_precisions = {torch.float32}, # Run with FP32\n",
     "    workspace_size = 1 << 22\n",
     ")"
 ]
```
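The notebook fix is small but real: `enabled_precisions` in `torch_tensorrt.compile` expects a set of dtypes, so the bare `torch.float32` becomes `{torch.float32}`. A minimal sketch of the corrected call, with a stand-in torchvision model since the notebook's EfficientNet is loaded elsewhere (the model choice is an assumption; the input shape and workspace size come from the cell above):

```python
import torch
import torch_tensorrt
import torchvision.models as models

# Stand-in model (assumption); the notebook compiles an EfficientNet instead.
model = models.resnet18(weights=None).eval().cuda()

trt_model_fp32 = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((128, 3, 224, 224), dtype=torch.float32)],
    enabled_precisions={torch.float32},  # a set of dtypes, not a bare dtype
    workspace_size=1 << 22,
)
```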
