Skip to content

Commit cbab4ca

Browse files
author
Chuyang Deng
committed
fix mnist script
1 parent a40d345 commit cbab4ca

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/data/pytorch_mnist/mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
4444
dataset = datasets.MNIST(training_dir, train=True, transform=transforms.Compose([
4545
transforms.ToTensor(),
4646
transforms.Normalize((0.1307,), (0.3081,))
47-
]))
47+
]), download=True)
4848
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
4949
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train_sampler is None,
5050
sampler=train_sampler, **kwargs)
@@ -57,7 +57,7 @@ def _get_test_data_loader(training_dir, **kwargs):
5757
datasets.MNIST(training_dir, train=False, transform=transforms.Compose([
5858
transforms.ToTensor(),
5959
transforms.Normalize((0.1307,), (0.3081,))
60-
])),
60+
]), download=True),
6161
batch_size=1000, shuffle=True, **kwargs)
6262

6363

0 commit comments

Comments
 (0)