Skip to content

Commit 8ad99b6

Browse files
authored
Skip logging the input tensors to the loss block. (aws#64)
1 parent 137b4bf commit 8ad99b6

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

smdebug/mxnet/hook.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,9 @@ def forward_hook(self, block, inputs, outputs):
154154
# This overwhelms the logs; turn back on if you really need it
155155
# logger.debug("Processing the global step {0} for block {1}".format(self.step, block_name))
156156

157-
# Output input tensor
158-
self._write_inputs(block_name, inputs)
157+
# Output input tensor if it is not a loss block
158+
if isinstance(block, mx.gluon.loss.Loss) is False:
159+
self._write_inputs(block_name, inputs)
159160

160161
# Output output tensors
161162
self._write_outputs(block_name, outputs)

tests/mxnet/test_hook_loss_collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def test_loss_collection_default():
3333
loss_val = loss_tensor.value(step_num=1)
3434
assert len(loss_val) > 0
3535

36+
# Assert that we are not logging the inputs to loss block.
37+
input_loss_tensors = tr.tensor_names(regex=".*loss._input*")
38+
assert len(input_loss_tensors) == 0
3639
shutil.rmtree(out_dir)
3740

3841

0 commit comments

Comments
 (0)