Commit 90ef076
Add loss registration to PT by calling hook.register_loss(criterion) (aws#269)
* Add loss registration to PT by calling hook.register_loss(criterion)
* Change output0 to output_0 for consistency
* Hide some overwhelming logging
* Case-sensitive checks
* Fix loss regex
1 parent 72dd7e7 commit 90ef076
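
In short, the new API is used like this; a minimal sketch assembled from tests/pytorch/test_loss.py below (Net, the output directory, and the save interval are simply the values used in that test):

import torch.nn as nn
import torch.optim as optim
import tornasole.pytorch as ts

net = Net()                        # any torch.nn.Module; Net is defined in the test below
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9)

hook = ts.TornasoleHook(out_dir='/tmp/pytorch_test_loss',
                        save_config=ts.SaveConfig(save_interval=1))
hook.register_hook(net)        # weights, biases, gradients, layer inputs/outputs
hook.register_loss(criterion)  # new in this commit: the loss is saved as 'CrossEntropyLoss_output_0'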

File tree

7 files changed: +148 / -39 lines changed

tests/pytorch/test_loss.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+import pytest
+import shutil
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import tornasole.pytorch as ts
+from tornasole.trials import Trial, create_trial
+
+class Net(nn.Module):
+    """CIFAR-10 classification network structure."""
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = x.view(-1, 16 * 5 * 5)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+def test_register_loss():
+    """Test that the loss is saved as a tensor."""
+    out_dir = '/tmp/pytorch_test_loss'
+    shutil.rmtree(out_dir, ignore_errors=True)
+
+    net = Net()
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
+
+    hook = ts.TornasoleHook(
+        out_dir=out_dir,
+        # With the default SaveConfig, the weights are not saved (only loss/gradient).
+        # The weight tensors will be saved only at the final step, and only if it's a multiple
+        # of save_interval. Issue with flushing?
+        save_config=ts.SaveConfig(save_interval=1),
+    )
+    hook.register_hook(net)
+    hook.register_loss(criterion)  # This is the important line
+
+    batch_size = 1
+    n_steps = 5
+    # Use the same data at each step to test loss decreasing
+    inputs, labels = torch.rand(batch_size, 3, 32, 32), torch.zeros(batch_size).long()
+    for _ in range(n_steps):
+        optimizer.zero_grad()
+        outputs = net(inputs)
+        loss = criterion(outputs, labels)
+        loss.backward()
+        optimizer.step()
+
+    # TODO(nieljare): Remove reliance on hook._cleanup()
+    # What if the user has a training loop, then calls the Trials API in the same Python script
+    # (like we do here)? Then it'll crash, likewise in a Jupyter notebook.
+    hook._cleanup()
+
+    trial = create_trial(path=out_dir, name='run')
+    loss_coll = hook.collection_manager.get('losses')
+    assert len(loss_coll.get_tensor_names()) == 3
+
+    loss_tensor = trial.tensor('CrossEntropyLoss_output_0')
+    print(f"loss_tensor.steps() = {loss_tensor.steps()}")
+
+    gradient_tensor = trial.tensor('gradient/Net_fc1.weight')
+    print(f"gradient_tensor.steps() = {gradient_tensor.steps()}")
+
+    weight_tensor = trial.tensor('Net_fc1.weight')
+    print(f"weight_tensor.steps() = {weight_tensor.steps()}")
+
+    assert len(trial.available_steps()) == n_steps
+    assert len(weight_tensor.steps()) == n_steps
+    assert len(gradient_tensor.steps()) == n_steps
+    assert len(loss_tensor.steps()) == n_steps
+    assert loss_tensor.value(0) > loss_tensor.value(4)

tests/pytorch/test_simple_write.py

Lines changed: 14 additions & 14 deletions
@@ -43,12 +43,12 @@ def __init__(self, mode='weights-bias-gradients', to_save=[]):
         self.saved['relu2_input_0'] = dict()
         self.saved['fc3_input_0'] = dict()
         self.saved['Net_input_0'] = dict()
-        self.saved['fc1_output0'] = dict()
-        self.saved['relu1_output0'] = dict()
-        self.saved['fc2_output0'] = dict()
-        self.saved['relu2_output0'] = dict()
-        self.saved['fc3_output0'] = dict()
-        self.saved['Net_output0'] = dict()
+        self.saved['fc1_output_0'] = dict()
+        self.saved['relu1_output_0'] = dict()
+        self.saved['fc2_output_0'] = dict()
+        self.saved['relu2_output_0'] = dict()
+        self.saved['fc3_output_0'] = dict()
+        self.saved['Net_output_0'] = dict()


     def forward(self, x_in):
@@ -73,12 +73,12 @@ def forward(self, x_in):
         self.saved['fc3_input_0'][self.step] = relu2_out.data.numpy().copy()
         self.saved['Net_input_0'][self.step] = fc3_out.data.numpy().copy()

-        self.saved['fc1_output0'][self.step] = fc1_out.data.numpy().copy()
-        self.saved['relu1_output0'][self.step] = relu1_out.data.numpy().copy()
-        self.saved['fc2_output0'][self.step] = fc2_out.data.numpy().copy()
-        self.saved['relu2_output0'][self.step] = relu2_out.data.numpy().copy()
-        self.saved['fc3_output0'][self.step] = fc3_out.data.numpy().copy()
-        self.saved['Net_output0'][self.step] = out.data.numpy().copy()
+        self.saved['fc1_output_0'][self.step] = fc1_out.data.numpy().copy()
+        self.saved['relu1_output_0'][self.step] = relu1_out.data.numpy().copy()
+        self.saved['fc2_output_0'][self.step] = fc2_out.data.numpy().copy()
+        self.saved['relu2_output_0'][self.step] = relu2_out.data.numpy().copy()
+        self.saved['fc3_output_0'][self.step] = fc3_out.data.numpy().copy()
+        self.saved['Net_output_0'][self.step] = out.data.numpy().copy()
         return out

 # Create a tornasole hook. The initilization of hook determines which tensors
@@ -202,7 +202,7 @@ def saveall_test_helper(hook=None):
     weights = ['Net_fc1.weight', 'Net_fc2.weight', 'Net_fc3.weight']
     bias = ['Net_fc1.bias', 'Net_fc2.bias', 'Net_fc3.bias']
     inputs = ['fc1_input_0', 'relu1_input_0', 'fc2_input_0', 'relu2_input_0', 'fc3_input_0']
-    outputs = ['fc1_output0', 'relu1_output0', 'fc2_output0', 'relu2_output0', 'fc3_output0']
+    outputs = ['fc1_output_0', 'relu1_output_0', 'fc2_output_0', 'relu2_output_0', 'fc3_output_0']
     tensors = grads + bias + weights + inputs + outputs

     assert len(trial.available_steps()) == len(save_steps)
@@ -237,7 +237,7 @@ def helper_test_multi_collections(hook, out_dir):
     weights = ['Net_fc1.weight', 'Net_fc2.weight', 'Net_fc3.weight']
     bias = ['Net_fc1.bias', 'Net_fc2.bias', 'Net_fc3.bias']
     inputs = ['fc1_input_0', 'relu1_input_0', 'relu2_input_0']
-    outputs = ['fc1_output0', 'relu1_output0', 'relu2_output0']
+    outputs = ['fc1_output_0', 'relu1_output_0', 'relu2_output_0']
     tensors = grads + bias + weights + inputs + outputs

     assert len(trial.available_steps()) == len(save_steps)

tornasole/core/hook.py

Lines changed: 1 addition & 1 deletion
@@ -248,7 +248,7 @@ class CallbackHook(BaseHook):
     __metaclass__ = ABCMeta
     INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')
     INPUT_TENSOR_SUFFIX = '_input_'
-    OUTPUT_TENSOR_SUFFIX = '_output'
+    OUTPUT_TENSOR_SUFFIX = '_output_'
    GRADIENT_PREFIX = 'gradient/'

     def __init__(self,
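
For context (not part of the diff): the hook builds tensor names from the module name plus these suffixes and an argument index, so this one-character change is what turns names like fc1_output0 into fc1_output_0 and gives the loss its CrossEntropyLoss_output_0 name. A rough sketch of the assumed naming scheme:

def output_tensor_name(module_name, index):
    # Assumed scheme: <module name> + OUTPUT_TENSOR_SUFFIX + <output index>
    OUTPUT_TENSOR_SUFFIX = '_output_'
    return module_name + OUTPUT_TENSOR_SUFFIX + str(index)

assert output_tensor_name('fc1', 0) == 'fc1_output_0'
assert output_tensor_name('CrossEntropyLoss', 0) == 'CrossEntropyLoss_output_0'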

tornasole/core/utils.py

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ def get_worker_name_from_collection_file(filename):
     return re.match(worker_name_regex, filename).group(1)

 def match_inc(tname, include):
+    """Matches anywhere in the string, doesn't require full match."""
     for inc in include:
         if re.search(inc, tname):
             return True
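
A self-contained illustration of that docstring (the snippet below just mirrors match_inc's re.search-based logic; it is not the library code itself). Because re.search matches anywhere in the tensor name and is case-sensitive, a pattern like 'Loss' picks up 'CrossEntropyLoss_output_0' without needing a '.*' prefix:

import re

def match_inc(tname, include):
    # Mirror of the helper: re.search matches anywhere; re.match would only match at the start.
    return any(re.search(inc, tname) for inc in include)

assert match_inc('CrossEntropyLoss_output_0', ['Loss'])   # substring match
assert not match_inc('my_custom_loss', ['Loss'])          # case-sensitive: 'Loss' != 'loss'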

tornasole/mxnet/hook.py

Lines changed: 2 additions & 1 deletion
@@ -107,7 +107,8 @@ def forward_hook(self, block, inputs, outputs):
             return

         block_name = block.name
-        logger.debug("Processing the global step {0} for block {1}".format(self.step, block_name))
+        # This overwhelms the logs; turn back on if you really need it
+        # logger.debug("Processing the global step {0} for block {1}".format(self.step, block_name))

         # Output input tensor
         self._write_inputs(block_name, inputs)

tornasole/pytorch/hook.py

Lines changed: 46 additions & 22 deletions
@@ -14,7 +14,8 @@
     CollectionKeys.WEIGHTS,
     CollectionKeys.BIASES,
     CollectionKeys.GRADIENTS,
-    CollectionKeys.DEFAULT
+    CollectionKeys.DEFAULT,
+    CollectionKeys.LOSSES,
 ]


@@ -68,9 +69,9 @@ def log_params(self, module):
         params = module.named_parameters()
         for name, param in params:
             pname = module_name + '_' + name
-            self.logger.debug(
-                "Processing the global step {0} for parameter {1}".format(
-                    self.step, pname))
+            # This overwhelms the logs; turn back on if you really need it
+            # self.logger.debug(
+            #     "Processing the global step {0} for parameter {1}".format(self.step, pname))
             self._write_tensor(tensor_name=pname, tensor_value=param.data)

     # This hook is invoked by trainer prior to running the forward pass.
@@ -93,15 +94,16 @@ def forward_pre_hook(self, module, inputs):
         if self.last_saved_step is not None and not self.exported_collections:
             self.export_collections()
             self.exported_collections = True
-
+
     # This hook is invoked by trainer after running the forward pass.
     def forward_hook(self, module, inputs, outputs):
         if self.collections_in_this_step is None:
             logging.debug("Skipping the global step {0}".format(self.step))
             return

         module_name = self.module_maps[module]
-        logger.debug("Processing the global step {0} for module {1}".format(self.step, module_name))
+        # This overwhelms the logs; turn back on if you really need it
+        # logger.debug("Processing the global step {0} for module {1}".format(self.step, module_name))

         # Output input tensor
         self._write_inputs(module_name, inputs)
@@ -119,39 +121,61 @@ def back(grad):
             self._write_tensor(tensor_name=self.GRADIENT_PREFIX + tname, tensor_value=grad)
         return back

-    def _recursive_apply(self, module):
-        """
-        This function is "applied" to every child in the block. This function in turn
-        registers the forward hook to each module. It helps logging the input output tensors
-        of that module.
-        """
-        module.register_forward_hook(self.forward_hook)
-
     def _backward_apply(self, module):
+        """Apply the function `self.backward_hook` as a callback to each parameter in `module`.
+
+        This will capture the gradients.
+        """
         params = module.named_parameters()
         for name, param in params:
             pname = module._get_name() + '_' + name
             param.register_hook(self.backward_hook(pname))

+    def closure_for_registering_forward_hook(self, module):
+        """Lambda functions don't work here."""
+        module.register_forward_hook(self.forward_hook)
+
     def register_hook(self, module):
         """
         This function registers the forward hook. If user wants to register the hook
         for every child in the given block, then the function calls "apply" API for
         registration of the hook.
-        The hook is registered recursively for all blocks
+        The hook is registered recursively for all blocks.
         """
+        # Typechecking
         if not isinstance(module, torch.nn.Module):
-            logger.error("The given module type {0} is not currently supported by Tornasole Hook".format(
-                module.__class__.__name__))
-            return
-        module.register_forward_pre_hook(self.forward_pre_hook)
+            raise ValueError(f"Module type {module.__class__.__name__} must be type torch.nn.Module")

-        for layer in list(module.named_modules()):
-            self.module_maps[layer[1]] = layer[0]
+        # Create a mapping from modules to their names
+        for name, submodule in module.named_modules():
+            assert submodule not in self.module_maps, f"Don't register module={module} twice"
+            self.module_maps[submodule] = name
         self.module_maps[module] = module._get_name()
-        module.apply(self._recursive_apply)
+
+        # Use `forward_pre_hook` for the entire net
+        module.register_forward_pre_hook(self.forward_pre_hook)
+
+        # Set `self.forward_hook` as a callback for each submodule/layer.
+        # `module.apply(fn)` calls fn for each submodule in module.children()
+        module.apply(self.closure_for_registering_forward_hook)
+
+        # Capture the gradient for each parameter in the net
         self._backward_apply(module)

+    def register_loss(self, loss_module):
+        """Register something like `criterion = nn.CrossEntropyLoss()`."""
+        # Typechecking
+        assert isinstance(loss_module, torch.nn.modules.loss._Loss), (
+            f"loss_module={loss_module} must be a subclass of `torch.nn.modules.loss._Loss`, "
+            f"but has class hierarchy {type.mro(type(loss_module))}"
+        )
+        # Register the module in self.module_maps
+        name = loss_module._get_name()
+        self.module_maps[loss_module] = name
+        # Add a callback to the forward pass
+        loss_module.register_forward_hook(self.forward_hook)
+
+
     @staticmethod
     def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):
         return get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs)
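
Hedged note: since register_loss routes the criterion through the same forward_hook as any other module, its inputs and output should be written with the standard naming scheme. Under that assumption, the three tensor names counted in the 'losses' collection by test_register_loss would be as listed below (only the output name is confirmed by the test; the input names are an inference):

expected_loss_tensors = [
    'CrossEntropyLoss_input_0',   # predictions passed to the criterion (assumed name)
    'CrossEntropyLoss_input_1',   # labels passed to the criterion (assumed name)
    'CrossEntropyLoss_output_0',  # the loss value itself (asserted in test_register_loss)
]
assert len(expected_loss_tensors) == 3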

tornasole/pytorch/torch_collection.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def _register_default_collections(self):
         self.get(CollectionKeys.WEIGHTS).include('^(?!gradient).*weight')
         self.get(CollectionKeys.BIASES).include('^(?!gradient).*bias')
         self.get(CollectionKeys.GRADIENTS).include('^gradient')
-        self.get(CollectionKeys.LOSSES).include('.*loss')
+        self.get(CollectionKeys.LOSSES).include('Loss')

     def create_collection(self, name):
         super().create_collection(name, cls=Collection)
