Skip to content

Commit c9eb769

Browse files
authored
Add ability to only save shapes of tensors (aws#328)
1 parent 47ceaf0 commit c9eb769

22 files changed: +681 −205 lines changed

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ include_workers
9696
include_regex
9797
reductions
9898
save_raw_tensor
99+
save_shape
99100
save_interval
100101
save_steps
101102
start_step

smdebug/core/hook.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
size_and_shape,
4747
validate_custom_tensor_value,
4848
)
49-
from smdebug.core.writer import FileWriter
49+
from smdebug.core.writer import FileWriter, ShapeWriter
5050
from smdebug.exceptions import InvalidCollectionConfiguration
5151

5252
try:
@@ -222,7 +222,7 @@ def __init__(
222222
self.mode = ModeKeys.GLOBAL
223223
self.mode_steps = {ModeKeys.GLOBAL: init_step}
224224
self.writer = None
225-
225+
self.shape_writer = None
226226
if is_sagemaker_job() and SageMakerFileMetricsWriter is not None:
227227
self.metrics_writer = SageMakerFileMetricsWriter()
228228
else:
@@ -343,6 +343,12 @@ def _get_collections_to_save_for_step(self) -> Set["Collection"]:
343343
)
344344
return self._collections_to_save_for_step
345345

346+
def _saving_shapes_in_step(self) -> bool:
347+
for coll in self._get_collections_to_save_for_step():
348+
if coll.reduction_config.save_shape is True:
349+
return True
350+
return False
351+
346352
def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
347353
self._assert_prep()
348354
# for tf this will be prepopulated in check_and_add_tensor
@@ -404,6 +410,17 @@ def _prepare_collections(self):
404410
self.prepared_collections = True
405411

406412
#### End of Save Manager methods ####
413+
@staticmethod
414+
def _close_given_writer_map(writer_dict):
415+
# Delete all the dist training writers
416+
to_delete_writers = []
417+
for key, writer in writer_dict.items():
418+
# close calls flush
419+
writer.close()
420+
to_delete_writers.append(key)
421+
422+
for key in to_delete_writers:
423+
del writer_dict[key]
407424

408425
def _close_writers(self) -> None:
409426
if self.dry_run:
@@ -417,16 +434,11 @@ def _close_writers(self) -> None:
417434
self.writer.close()
418435
self.writer = None
419436

420-
to_delete_writers = []
437+
self._close_given_writer_map(self.tb_writers)
421438

422-
# Delete all the tb writers
423-
for mode, writer in self.tb_writers.items():
424-
if writer is not None:
425-
writer.flush()
426-
writer.close()
427-
to_delete_writers.append(mode)
428-
for mode in to_delete_writers:
429-
del self.tb_writers[mode]
439+
if self.shape_writer is not None:
440+
self.shape_writer.close()
441+
self.shape_writer = None
430442

431443
def _initialize_writers(self, only_initialize_if_missing=False) -> None:
432444
# Function is overridden in smdebug/tensorflow/base_hook.py
@@ -454,17 +466,32 @@ def _initialize_writers(self, only_initialize_if_missing=False) -> None:
454466
if self.save_all_workers is False:
455467
if self.worker != self.chief_worker:
456468
return
469+
457470
self.writer = FileWriter(trial_dir=self.out_dir, step=self.step, worker=self.worker)
458471

459-
def _get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
472+
if self._saving_shapes_in_step():
473+
self.shape_writer = ShapeWriter(
474+
trial_dir=self.out_dir,
475+
step=self.step,
476+
worker=self.worker,
477+
index_writer=self.writer.index_writer,
478+
)
479+
480+
def _get_single_process_writers(self, shape_writers=False) -> List[FileWriter]:
481+
if shape_writers is False:
482+
return [self.writer] if self.writer else []
483+
else:
484+
return [self.shape_writer] if self.shape_writer else []
485+
486+
def _get_writers(self, tensor_name, tensor_ref=None, shape_writers=False) -> List[FileWriter]:
460487
"""
461488
:param tensor_name:
462489
:param tensor_ref: used by TF
463490
:return: List[FileWriter]
464491
"""
465492
if self.save_all_workers is False and self.worker != self.chief_worker:
466493
return []
467-
return [self.writer] if self.writer else []
494+
return self._get_single_process_writers(shape_writers)
468495

469496
def _maybe_get_tb_writer(self) -> Optional[FileWriter]:
470497
""" Returns a FileWriter object if `hook.tensorboard_dir` has been specified, else None.
@@ -726,6 +753,28 @@ def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_
726753
self._write_raw_tensor_simple(tensor_name, tensor_value, tensor_ref=tensor_ref)
727754
break
728755

756+
def _write_shape(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
757+
shape_writers = self._get_writers(tensor_name, tensor_ref=tensor_ref, shape_writers=True)
758+
for s_col in save_collections:
759+
reduction_config = s_col.reduction_config
760+
if self.dry_run is False and reduction_config.save_shape is True:
761+
numpy_tensor_value = self._make_numpy_array(tensor_value)
762+
this_size, this_shape = size_and_shape(numpy_tensor_value)
763+
if tensor_ref is not None and tensor_ref.tf_obj is not None:
764+
original_name = tensor_ref.tf_obj.name
765+
else:
766+
original_name = None
767+
768+
for writer in shape_writers:
769+
writer.write_shape(
770+
tensor_name,
771+
this_shape,
772+
self.mode,
773+
self.mode_steps[self.mode],
774+
original_name=original_name,
775+
)
776+
break
777+
729778
def _write_raw_tensor_simple(self, tensor_name, tensor_value, tensor_ref=None, timestamp=None):
730779
# tensor_ref is used by TF
731780
# todo: if fp16, check perf of saving as fp16 in proto vs as fp32
@@ -805,6 +854,9 @@ def _write_for_tensor(self, tensor_name, tensor_value, save_collections, tensor_
805854
:param save_collections: list of collections which are being saved for this step
806855
"""
807856
self._log_save(tensor_name, save_collections)
857+
858+
self._write_shape(tensor_name, tensor_value, save_collections, tensor_ref=tensor_ref)
859+
808860
# write reductions defined for collections this tensor may be part of
809861
self._write_reductions(tensor_name, tensor_value, save_collections, tensor_ref=tensor_ref)
810862

smdebug/core/index_reader.py

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
MISSING_EVENT_FILE_RETRY_LIMIT,
1717
MISSING_EVENT_FILE_RETRY_LIMIT_KEY,
1818
)
19-
from smdebug.core.locations import IndexFileLocationUtils, TensorLocation
19+
from smdebug.core.locations import IndexFileLocationUtils, TensorLocation, TensorShape
2020
from smdebug.core.logger import get_logger
2121
from smdebug.core.modes import ModeKeys
2222
from smdebug.core.s3_utils import list_s3_objects
@@ -120,12 +120,22 @@ def fetch_tensor_value(self, tensor_location: TensorLocation):
120120
def list_event_files(self, start_after_prefix):
121121
pass
122122

123-
@abstractmethod
124123
def load_tensor_data_from_index_files(
125124
self, start_after_key=None, range_steps=None
126125
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
127126
"""Return a triply nested dict referring to tensor data."""
128127

128+
responses, steps, last_index_token, workers = self.read_index_files(
129+
start_after_key, range_steps
130+
)
131+
132+
tensor_data = {}
133+
for step, response, worker in zip(steps, responses, workers):
134+
tensor_data = self._update_tensors_from_json(
135+
tensor_data, step, response, self.path, worker
136+
)
137+
return tensor_data, last_index_token
138+
129139
@abstractmethod
130140
def _is_event_file_present(self, file_name) -> bool:
131141
pass
@@ -203,8 +213,10 @@ def _validate(index_dict):
203213
raise IndexReaderException("meta section is not present")
204214
if len(index_dict["meta"]) == 0:
205215
raise IndexReaderException("meta section is empty")
206-
if "tensor_payload" not in index_dict:
207-
raise IndexReaderException("tensor_payload section is not present")
216+
if "tensor_payload" not in index_dict and "shape_payload" not in index_dict:
217+
raise IndexReaderException(
218+
"neither tensor_payload nor shape_payload sections are present"
219+
)
208220

209221
def _update_tensors_from_json(
210222
self, index_tensors_dict, step, response: bytes, path, worker
@@ -233,28 +245,41 @@ def _update_tensors_from_json(
233245
mode = index_meta["mode"]
234246
mode = ModeKeys[mode.strip()]
235247
mode_step = index_meta["mode_step"]
236-
event_file_name = os.path.join(path, index_meta["event_file_name"])
237-
tensors = index_dict["tensor_payload"]
238-
for tensor in tensors:
239-
tensor_name = tensor["tensorname"]
240-
start_idx = tensor["start_idx"]
241-
length = tensor["length"]
242-
tensor_location = TensorLocation(
243-
tensor_name, mode, mode_step, event_file_name, start_idx, length, worker
244-
)
248+
249+
to_update_index_dict = []
250+
251+
if "tensor_payload" in index_dict and len(index_dict["tensor_payload"]):
252+
event_file_name = os.path.join(path, index_meta["event_file_name"])
253+
for tensor in index_dict["tensor_payload"]:
254+
tensor_name = tensor["tensorname"]
255+
start_idx = tensor["start_idx"]
256+
length = tensor["length"]
257+
tensor_location = TensorLocation(
258+
tensor_name, mode, mode_step, event_file_name, start_idx, length, worker
259+
)
260+
to_update_index_dict.append((tensor_name, step, tensor_location))
261+
262+
if "shape_payload" in index_dict and len(index_dict["shape_payload"]):
263+
for tensor in index_dict["shape_payload"]:
264+
tensor_name = tensor["tensorname"]
265+
original_name = tensor["originalname"]
266+
shape = tensor["shape"]
267+
ts = TensorShape(tensor_name, mode, mode_step, shape, original_name)
268+
to_update_index_dict.append((tensor_name, step, ts))
269+
270+
for tu in to_update_index_dict:
271+
tensor_name, step, obj = tu
272+
if isinstance(obj, TensorLocation):
273+
obj_dict = {"tensor_location": obj}
274+
elif isinstance(obj, TensorShape):
275+
obj_dict = {"tensor_shape": obj}
245276
if tensor_name in index_tensors_dict:
246277
if step in index_tensors_dict[tensor_name]:
247-
index_tensors_dict[tensor_name][step].update(
248-
{worker: {"tensor_location": tensor_location}}
249-
)
278+
index_tensors_dict[tensor_name][step].update({worker: obj_dict})
250279
else:
251-
index_tensors_dict[tensor_name].update(
252-
{step: {worker: {"tensor_location": tensor_location}}}
253-
)
280+
index_tensors_dict[tensor_name].update({step: {worker: obj_dict}})
254281
else:
255-
index_tensors_dict[tensor_name] = {
256-
step: {worker: {"tensor_location": tensor_location}}
257-
}
282+
index_tensors_dict[tensor_name] = {step: {worker: obj_dict}}
258283
return index_tensors_dict
259284

260285

@@ -285,22 +310,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
285310
tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
286311
return tensor_data
287312

288-
def load_tensor_data_from_index_files(
289-
self, start_after_key=None, range_steps=None
290-
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
291-
"""Return a triply nested dict referring to tensor data."""
292-
293-
responses, steps, last_index_token, workers = self.read_index_files(
294-
start_after_key, range_steps
295-
)
296-
297-
tensor_data = {}
298-
for step, response, worker in zip(steps, responses, workers):
299-
tensor_data = self._update_tensors_from_json(
300-
tensor_data, step, response, self.path, worker
301-
)
302-
return tensor_data, last_index_token
303-
304313
def read_index_files(
305314
self, start_after_key: str, range_steps=None
306315
) -> Tuple[List[bytes], list, str, List[str]]:
@@ -398,21 +407,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
398407
tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
399408
return tensor_data
400409

401-
def load_tensor_data_from_index_files(
402-
self, start_after_key=None, range_steps=None
403-
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
404-
"""Return a triply nested dict referring to tensor data."""
405-
406-
responses, steps, last_index_token, workers = self.read_index_files(
407-
start_after_key, range_steps
408-
)
409-
tensor_data = {}
410-
for step, response, worker in zip(steps, responses, workers):
411-
tensor_data = self._update_tensors_from_json(
412-
tensor_data, step, response, self.path, worker
413-
)
414-
return tensor_data, last_index_token
415-
416410
def read_index_files(
417411
self, start_after_key: str, range_steps=None
418412
) -> Tuple[List[bytes], list, str, List[str]]:

smdebug/core/locations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ def to_dict(self):
2424
return {"tensorname": self.tensorname, "start_idx": self.start_idx, "length": self.length}
2525

2626

27+
class TensorShape:
28+
def __init__(self, name, mode, mode_step, shape, original_name=None):
29+
if original_name is None:
30+
original_name = name
31+
self.name = name
32+
self.original_name = original_name
33+
self.mode = mode
34+
self.mode_step = mode_step
35+
self.shape = tuple(shape)
36+
37+
def to_dict(self):
38+
return {"tensorname": self.name, "originalname": self.original_name, "shape": self.shape}
39+
40+
2741
STEP_NUMBER_FORMATTING_LENGTH = "012"
2842

2943

0 commit comments

Comments (0)