|
7 | 7 | import torch.nn.functional as F
|
8 | 8 | import torchvision
|
9 | 9 | import torchvision.transforms as transforms
|
| 10 | +from packaging import version |
10 | 11 |
|
11 | 12 |
|
12 | 13 | def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
|
13 | 14 | transform = transforms.Compose(
|
14 | 15 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
15 | 16 | )
|
16 | 17 |
|
| 18 | + # Temporary Change to allow the test to run with pytorch 1.7 RC3 |
| 19 | + # Smdebug breaks when num_workers>0 for Pytorch 1.7.0 |
| 20 | + if version.parse(torch.__version__) >= version.parse("1.7.0"): |
| 21 | + num_workers = 0 |
| 22 | + else: |
| 23 | + num_workers = 2 |
| 24 | + |
17 | 25 | trainset = torchvision.datasets.CIFAR10(
|
18 | 26 | root="./data", train=True, download=True, transform=transform
|
19 | 27 | )
|
20 |
| - trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) |
| 28 | + trainloader = torch.utils.data.DataLoader( |
| 29 | + trainset, batch_size=4, shuffle=True, num_workers=num_workers |
| 30 | + ) |
21 | 31 |
|
22 | 32 | testset = torchvision.datasets.CIFAR10(
|
23 | 33 | root="./data", train=False, download=True, transform=transform
|
24 | 34 | )
|
25 |
| - testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) |
| 35 | + testloader = torch.utils.data.DataLoader( |
| 36 | + testset, batch_size=4, shuffle=False, num_workers=num_workers |
| 37 | + ) |
26 | 38 |
|
27 | 39 | classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
|
28 | 40 | return trainloader, testloader
|
|
0 commit comments