Skip to content

Commit 91eb151

Browse files
authored
tensor_names() returns only tensornames (aws#159)
* Fixes issue: awslabs/sagemaker-debugger#113
* Bug fix in tensor_names(): it no longer includes extra reduction names.
1 parent e05c490 commit 91eb151

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

smdebug/trials/trial.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -289,24 +289,32 @@ def _populate_mode_to_tensor_name_map(self, tensor: TensorLocation) -> None:
289289
self.mode_to_tensors_map[tensor.mode].add(tensor.tensorname)
290290

291291
def _add_tensor(self, step_num, worker, tensor_object: TensorLocation):
292-
to = tensor_object
293-
# self.worker_set.add(worker)
294-
if REDUCTIONS_PREFIX in to.tensorname:
295-
tname, red_name, abs = reverse_reduction_tensor_name(to.tensorname)
292+
is_reduction = False
293+
294+
if REDUCTIONS_PREFIX in tensor_object.tensorname:
295+
tname, red_name, abs = reverse_reduction_tensor_name(tensor_object.tensorname)
296+
tensor_object.tensorname = tname
297+
is_reduction = True
296298
else:
297-
tname = to.tensorname
299+
tname = tensor_object.tensorname
300+
298301
if tname not in self._tensors:
299-
t = Tensor(tname, trial=self, cache=self.cache)
300-
self._tensors[tname] = t
301-
t = self._tensors[tname]
302-
self._populate_step_dict(to, step_num)
303-
self._populate_global_step_to_tensor_name_map(to, step_num)
304-
self._populate_workers_for_global_step(step_num, worker)
305-
self._populate_mode_to_tensor_name_map(to)
306-
if REDUCTIONS_PREFIX in to.tensorname:
307-
t.add_reduction_step(to.mode, to.mode_step, worker, red_name, abs, to)
302+
tensor = Tensor(tname, trial=self, cache=self.cache)
303+
self._tensors[tname] = tensor
304+
305+
tensor = self._tensors[tname]
306+
307+
if is_reduction:
308+
tensor.add_reduction_step(
309+
tensor_object.mode, tensor_object.mode_step, worker, red_name, abs, tensor_object
310+
)
308311
else:
309-
t.add_step(to.mode, to.mode_step, worker, to)
312+
tensor.add_step(tensor_object.mode, tensor_object.mode_step, worker, tensor_object)
313+
314+
self._populate_step_dict(tensor_object, step_num)
315+
self._populate_global_step_to_tensor_name_map(tensor_object, step_num)
316+
self._populate_workers_for_global_step(step_num, worker)
317+
self._populate_mode_to_tensor_name_map(tensor_object)
310318

311319
def _tensors_matching_regex(self, regex_list) -> set:
312320
matched_tensornames = set()

tests/tensorflow/hooks/test_reductions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def helper_test_reductions(trial_dir, hook, save_raw_tensor):
1717

1818
tr = create_trial(trial_dir)
1919
assert len(tr.tensor_names()) == 3, tr.tensor_names()
20+
for step in tr.steps():
21+
assert len(tr.tensor_names(step=step)) == 3, tr.tensor_names()
2022
for tname in tr.tensor_names():
2123
t = tr.tensor(tname)
2224
if tname in tr.tensor_names(collection="losses"):

0 commit comments

Comments
 (0)