Skip to content

Commit cfcfd7a

Browse files
vandanavkVikas-kum
authored andcommitted
Skip logging the input tensors to the loss block (aws#86)
* Skip logging the input tensors to the loss block * Add loss inputs for PT functional loss * append mode name * Write correct mode for each scalar in write_scalars() * Increment global mode step number irrespective of which mode it is
1 parent adf9a8b commit cfcfd7a

File tree

6 files changed

+180
-63
lines changed

6 files changed

+180
-63
lines changed

smdebug/core/hook.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@
4747

4848

4949
class ScalarCache(object):
50-
def __init__(self, scalar_name, scalar_val, sm_metric, write_tb, write_event):
50+
def __init__(self, scalar_name, scalar_val, mode, sm_metric, write_tb, write_event):
5151
self.name = scalar_name
5252
self.value = scalar_val
53+
self.mode = mode
5354
self.sm_metric = sm_metric
5455
self.write_tb = write_tb
5556
self.write_event = write_event
@@ -440,6 +441,10 @@ def _increment_step(self):
440441

441442
self.step += 1
442443
self.mode_steps[self.mode] += 1
444+
445+
# Increment Global step number irrespective of what mode it is
446+
if self.mode != ModeKeys.GLOBAL:
447+
self.mode_steps[ModeKeys.GLOBAL] = self.step
443448
self._collections_to_save_for_step = None
444449

445450
def _write_state(self):
@@ -564,12 +569,15 @@ def _write_scalars(self):
564569
for scalar_obj in self.scalar_cache:
565570
scalar_name = scalar_obj.name
566571
scalar_val = scalar_obj.value
572+
scalar_mode = scalar_obj.mode
567573
sm_metric = scalar_obj.sm_metric
568574
write_tb = scalar_obj.write_tb
569575
write_event = scalar_obj.write_event
570576
if self.metrics_writer and sm_metric:
571577
self.metrics_writer.log_metric(
572-
scalar_name, scalar_val, iteration_number=self.mode_steps[self.mode]
578+
scalar_name + "_" + scalar_mode.name,
579+
scalar_val,
580+
iteration_number=self.mode_steps[scalar_mode],
573581
)
574582
if write_tb:
575583
tb_writer = self._maybe_get_tb_writer()
@@ -596,7 +604,7 @@ def save_scalar(self, name, value, sm_metric=False):
596604
val = self._make_numpy_array(value)
597605
if val.size != 1:
598606
raise TypeError(f"{name} has non scalar value of type: {type(value)}")
599-
scalar_obj = ScalarCache(name, val, sm_metric=True, write_tb=True, write_event=True)
607+
scalar_obj = ScalarCache(name, val, self.mode, sm_metric, write_tb=True, write_event=True)
600608
self.scalar_cache.append(scalar_obj)
601609

602610
def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
@@ -657,7 +665,12 @@ def _save_for_tensor(self, tensor_name, tensor_value, check_before_write=True):
657665
# Always log loss to Minerva
658666
tensor_val = np.mean(np_val)
659667
scalar_obj = ScalarCache(
660-
tensor_name, tensor_val, sm_metric=True, write_tb=False, write_event=False
668+
tensor_name,
669+
tensor_val,
670+
self.mode,
671+
sm_metric=True,
672+
write_tb=False,
673+
write_event=False,
661674
)
662675
self.scalar_cache.append(scalar_obj)
663676

smdebug/mxnet/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _register_default_collections(self):
2424
self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
2525
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
2626
self.get(CollectionKeys.GRADIENTS).include("^gradient")
27-
self.get(CollectionKeys.LOSSES).include(".*loss")
27+
self.get(CollectionKeys.LOSSES).include(".*loss._(?!input).*output")
2828

2929
def create_collection(self, name):
3030
super().create_collection(name, cls=Collection)

smdebug/pytorch/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _register_default_collections(self):
3939
self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
4040
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
4141
self.get(CollectionKeys.GRADIENTS).include("^gradient")
42-
self.get(CollectionKeys.LOSSES).include("[Ll]oss")
42+
self.get(CollectionKeys.LOSSES).include("[Ll]oss_(?!input).*output")
4343

4444
def create_collection(self, name):
4545
super().create_collection(name, cls=Collection)

0 commit comments

Comments
 (0)