Skip to content

Commit cdfbdd1

Browse files
authored
Fix data parallel sanity check (aws#169)
1 parent 58488b9 commit cdfbdd1

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

tests/pytorch/test_data_parallel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ def test_data_parallel():
3535
train(model, hook, torch.device(device), optimizer, num_steps=10)
3636

3737
trial = create_trial(out_dir)
38+
assert trial.steps() == [0, 1, 5]
3839
if device == "cpu":
39-
assert len(trial.tensor_names()) == 36
40+
assert len(trial.tensor_names()) == 37
4041
else:
41-
assert len(trial.tensor_names()) > 36
42+
assert len(trial.tensor_names()) > 37
4243

4344
shutil.rmtree(out_dir, ignore_errors=True)

tests/pytorch/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def train(model, hook, device, optimizer, num_steps=500, set_modes=False):
4040
hook.set_mode(modes.TRAIN)
4141

4242
model.train()
43-
count = 0
4443
# for batch_idx, (data, target) in enumerate(train_loader):
4544
for i in range(num_steps):
4645
batch_size = 32
@@ -49,6 +48,7 @@ def train(model, hook, device, optimizer, num_steps=500, set_modes=False):
4948
optimizer.zero_grad()
5049
output = model(Variable(data, requires_grad=True))
5150
loss = F.nll_loss(output, target)
51+
hook.record_tensor_value("nll_loss", tensor_value=loss)
5252
loss.backward()
5353
optimizer.step()
5454

0 commit comments

Comments
 (0)