@@ -39,14 +39,14 @@ def forward(self, x):
39
39
return F .log_softmax (x , dim = 1 )
40
40
41
41
42
def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
    """Build the sampler and data loader for the MNIST training split.

    Args:
        training_dir: Root directory passed to ``datasets.MNIST``.
        is_distributed: When truthy, a ``DistributedSampler`` over the dataset
            is created and handed to the loader; otherwise no sampler is used
            and the loader shuffles instead.
        batch_size: Batch size forwarded to the ``DataLoader``.
        **kwargs: Extra keyword arguments forwarded verbatim to the
            ``DataLoader`` constructor.

    Returns:
        A ``(train_sampler, train_loader)`` tuple. ``train_sampler`` is
        ``None`` when ``is_distributed`` is falsy.
    """
    logger.info('Get train data loader')

    # Convert images to tensors, then normalize with the fixed constants below.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = datasets.MNIST(training_dir, train=True, transform=preprocessing)

    if is_distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = None

    # Shuffling is only requested when no sampler is supplied; the two
    # options are never enabled together.
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **kwargs
    )
    return train_sampler, train_loader
52
52
@@ -94,7 +94,7 @@ def train(args):
94
94
if use_cuda :
95
95
torch .cuda .manual_seed (seed )
96
96
97
- train_sampler , train_loader = _get_train_data_loader (args .data_dir , is_distributed , ** kwargs )
97
+ train_sampler , train_loader = _get_train_data_loader (args .data_dir , is_distributed , args . batch_size , ** kwargs )
98
98
test_loader = _get_test_data_loader (args .data_dir , ** kwargs )
99
99
100
100
logger .debug ('Processes {}/{} ({:.0f}%) of train data' .format (
@@ -142,9 +142,11 @@ def train(args):
142
142
logger .debug ('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}' .format (
143
143
epoch , batch_idx * len (data ), len (train_loader .sampler ),
144
144
100. * batch_idx / len (train_loader ), loss .item ()))
145
- test (model , test_loader , device )
145
+ accuracy = test (model , test_loader , device )
146
146
save_model (model , args .model_dir )
147
147
148
+ logger .debug ('Overall test accuracy: {}' .format (accuracy ))
149
+
148
150
149
151
def test (model , test_loader , device ):
150
152
model .eval ()
@@ -159,9 +161,12 @@ def test(model, test_loader, device):
159
161
correct += pred .eq (target .view_as (pred )).sum ().item ()
160
162
161
163
test_loss /= len (test_loader .dataset )
164
+ accuracy = 100. * correct / len (test_loader .dataset )
165
+
162
166
logger .debug ('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
163
- test_loss , correct , len (test_loader .dataset ),
164
- 100. * correct / len (test_loader .dataset )))
167
+ test_loss , correct , len (test_loader .dataset ), accuracy ))
168
+
169
+ return accuracy
165
170
166
171
167
172
def model_fn (model_dir ):
@@ -181,6 +186,7 @@ def save_model(model, model_dir):
181
186
if __name__ == '__main__' :
182
187
parser = argparse .ArgumentParser ()
183
188
parser .add_argument ('--epochs' , type = int , default = 1 , metavar = 'N' )
189
+ parser .add_argument ('--batch-size' , type = int , default = 64 , metavar = 'N' )
184
190
185
191
# Container environment
186
192
parser .add_argument ('--hosts' , type = list , default = json .loads (os .environ ['SM_HOSTS' ]))
0 commit comments