Skip to content

Commit ef3b6b2

Browse files
authored
pytorch tmp (aws#382)
1 parent 301cbdd commit ef3b6b2

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

tests/zero_code_change/pt_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,34 @@
77
import torch.nn.functional as F
88
import torchvision
99
import torchvision.transforms as transforms
10+
from packaging import version
1011

1112

1213
def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
1314
transform = transforms.Compose(
1415
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
1516
)
1617

18+
# Temporary Change to allow the test to run with pytorch 1.7 RC3
19+
# Smdebug breaks when num_workers>0 for Pytorch 1.7.0
20+
if version.parse(torch.__version__) >= version.parse("1.7.0"):
21+
num_workers = 0
22+
else:
23+
num_workers = 2
24+
1725
trainset = torchvision.datasets.CIFAR10(
1826
root="./data", train=True, download=True, transform=transform
1927
)
20-
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
28+
trainloader = torch.utils.data.DataLoader(
29+
trainset, batch_size=4, shuffle=True, num_workers=num_workers
30+
)
2131

2232
testset = torchvision.datasets.CIFAR10(
2333
root="./data", train=False, download=True, transform=transform
2434
)
25-
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
35+
testloader = torch.utils.data.DataLoader(
36+
testset, batch_size=4, shuffle=False, num_workers=num_workers
37+
)
2638

2739
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
2840
return trainloader, testloader

0 commit comments

Comments
 (0)