
Commit 7783b31

add original PyTorch example codes of MNIST
1 parent 0157cc3 commit 7783b31

File tree

1 file changed: +144 -0 lines changed


examples/MNIST/MNIST_codes/main.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()
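
For reference (not part of the committed file): the script is self-contained and, assuming torchvision can download MNIST into ../data, it can be exercised directly with the flags it defines, for example:

    # quick smoke test: one logged batch per epoch, then stop
    python main.py --epochs 1 --dry-run

    # full training run, saving the trained weights to mnist_cnn.pt
    python main.py --save-model

A minimal sketch for reloading the saved state dict with the same Net class (again, an illustration rather than part of this commit):

    import torch

    model = Net()
    model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))
    model.eval()  # switch off dropout before running inference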
