
Commit f23e89c

Function to Test If the hook has been configured with the Default hook config (aws#332)
1 parent 0850ee1 commit f23e89c

13 files changed: +195 −60 lines

smdebug/core/config_constants.py

Lines changed: 2 additions & 0 deletions
@@ -41,3 +41,5 @@
 
 CALLABLE_CACHE_ENV_VAR = "SMDEBUG_KERAS_CALLABLE_CACHE_TYPE"
 DEFAULT_CALLABLE_CACHE = "CACHE_PER_MODE"
+
+DEFAULT_SAVED_COLLECTIONS = ["losses"]

smdebug/core/hook.py

Lines changed: 23 additions & 0 deletions
@@ -22,6 +22,7 @@
 )
 from smdebug.core.collection_manager import CollectionManager
 from smdebug.core.config_constants import (
+    DEFAULT_SAVED_COLLECTIONS,
     DEFAULT_WORKER_NAME,
     LATEST_GLOBAL_STEP_SAVED,
     LATEST_GLOBAL_STEP_SEEN,
@@ -343,6 +344,13 @@ def _get_collections_to_save_for_step(self) -> Set["Collection"]:
             )
         return self._collections_to_save_for_step
 
+    def is_tensor_saved_for_step(self, tensor_name):
+        collections_to_save = self._get_collections_to_save_for_step()
+        for c in collections_to_save:
+            if match_inc(tensor_name, c.include_regex):
+                return True
+        return False
+
     def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
         self._assert_prep()
         # for tf this will be prepopulated in check_and_add_tensor
@@ -364,6 +372,14 @@ def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
     def _get_default_collections(self):
         pass
 
+    def has_default_hook_configuration(self):
+        # Used in the internal framework forks to determine if the hook
+        # is using the default hook configuration
+        collections_being_saved = [x.name for x in self._collections_to_save]
+        if set(collections_being_saved) == set(DEFAULT_SAVED_COLLECTIONS):
+            return True
+        return False
+
     def _prepare_collections(self):
         """Populate collections_to_save and ensure every collection has
         a save_config and reduction_config."""
@@ -525,6 +541,13 @@ def _increment_step(self):
         self.mode_steps[ModeKeys.GLOBAL] = self.step
         self._collections_to_save_for_step = None
 
+    # Called in the internal AWS codebase to determine
+    # if a particular tensor value should be saved
+    def should_save_tensor_or_collection(self, tensor_name: str, collection_name: str) -> bool:
+        if self._is_collection_being_saved_for_step(collection_name):
+            return True
+        return self.is_tensor_saved_for_step(tensor_name)
+
     def _write_state(self):
         if self.state_store.is_checkpoint_updated():
             current_state = dict()
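
Taken together, the two new public methods give the internal framework forks a cheap guard: has_default_hook_configuration() answers "is this hook saving only the defaults?", and should_save_tensor_or_collection() answers "would this value be written at the current step?". A minimal usage sketch of the intended call pattern; the wrapper function and the compute_value callable are illustrative, not part of this commit:

    # Hypothetical caller in a framework fork: only materialize a tensor
    # value if the hook would actually write it at this step.
    def maybe_save(hook, tensor_name, collection_name, compute_value):
        # compute_value is an assumed zero-arg callable producing the value
        if hook.should_save_tensor_or_collection(tensor_name, collection_name):
            return compute_value()
        return None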

smdebug/tensorflow/base_hook.py

Lines changed: 9 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 # Local
 from .collection import CollectionKeys, CollectionManager
+from .constants import TF_DEFAULT_SAVED_COLLECTIONS
 from .singleton_utils import set_hook
 from .utils import (
     TFDistributionStrategy,
@@ -217,6 +218,14 @@ def export_collections(self):
         collection_file_name = f"{self.worker}_collections.json"
         self.collection_manager.export(self.out_dir, collection_file_name)
 
+    def has_default_hook_configuration(self):
+        # Used in AWS TF to determine if the hook
+        # is using the default hook configuration
+        collections_being_saved = [x.name for x in self._collections_to_save]
+        if set(collections_being_saved) == set(TF_DEFAULT_SAVED_COLLECTIONS):
+            return True
+        return False
+
     def _get_custom_and_default_collections(self) -> Tuple[Set["Collection"], Set["Collection"]]:
         if self._custom_collections is None:
             self._custom_collections = set()
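
The TensorFlow hook overrides the base implementation to compare against the TF-specific defaults. Since the comparison uses sets, the order in which collections were registered does not matter; a small self-contained illustration of the check being performed, using the TF defaults from constants.py below:

    # The override compares collection names as sets, so order is irrelevant.
    TF_DEFAULT_SAVED_COLLECTIONS = ["losses", "metrics", "sm_metrics"]
    collections_being_saved = ["metrics", "sm_metrics", "losses"]
    assert set(collections_being_saved) == set(TF_DEFAULT_SAVED_COLLECTIONS)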

smdebug/tensorflow/constants.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 SMDEBUG_GRADIENTS_KEY = "smdebug_gradients"
 SMDEBUG_LAYER_OUTPUTS_KEY = "smdebug_layer_outputs"
 SMDEBUG_PREFIX = "smdebug_"
+
+TF_DEFAULT_SAVED_COLLECTIONS = ["losses", "metrics", "sm_metrics"]

smdebug/tensorflow/keras.py

Lines changed: 82 additions & 53 deletions
@@ -30,6 +30,7 @@
     get_model_input_export_name,
     get_model_output_export_name,
     is_keras_optimizer,
+    is_tf_version_2_3_x,
     is_tf_version_2x,
 )
 
@@ -71,6 +72,14 @@ def __init__(
         )  # stores custom tensors saved by users every step
         self.saved_layers = dict()
         self.has_registered_model = False
+        # supports_tf_logs property was introduced in TF 2.3.0;
+        # it indicates to the framework that the callback is not
+        # limited to reading only numpy logs
+        self._supports_tf_logs = True
+        # TF 2.3.0 has a callback ordering bug;
+        # this flag indicates to the train_batch_begin callback
+        # that the step was already incremented in the on_train_begin callback
+        self.step_incremented_in_on_train_begin = False
 
     def _is_not_supported(self):
         if self.distribution_strategy is None:
@@ -109,7 +118,8 @@ def register_model(self, model):
         # It attaches a hook to every layer of the model to capture
         # layer values
         self.model = model
-        self._wrap_model_with_input_output_saver()
+        if self.tape is not None:
+            self._wrap_model_with_input_output_saver()
         self.has_registered_model = True
 
     def _get_matching_collections(
@@ -348,7 +358,10 @@ def _prepare_tensors_available_post_step(self):
 
         # Add tensor to custom collections
         for custom_coll in custom_collections:
-            if match_inc(tensor_ref.name, custom_coll.include_regex):
+            if (
+                match_inc(tensor_ref.name, custom_coll.include_regex)
+                and tensor_ref.tf_obj is not None
+            ):
                 custom_coll.add_for_mode(tensor_ref.tf_obj, self.mode)
                 if custom_coll not in self.tensor_to_collections[tensor_ref.name]:
                     self.tensor_to_collections[tensor_ref.name].add(custom_coll)
@@ -390,6 +403,12 @@ def _save_custom_tensors_post_step(self):
             self._save_tensor_to_file(tensor_name, tensor_value, collection_names)
         self.custom_tensors_to_save.clear()
 
+    def should_save_layer(self, layer_name):
+        # Called in AWS TF to determine
+        # if a particular layer value
+        # should be saved
+        return self.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
+
     def _save_tensor_to_file(self, tensor_name, tensor_value, collections):
         if isinstance(collections, set) is False:
             collections = {collections}
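
should_save_layer() is a thin wrapper that routes the layer name through the new core API against the LAYERS collection. An illustrative guard; the layer name and the capture helper are assumptions, not code from this commit:

    # Hypothetical: skip capturing a layer's output when it would not be saved.
    if hook.should_save_layer("conv2d_1"):
        capture_layer_output("conv2d_1")  # assumed framework-fork helper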
@@ -418,6 +437,31 @@
                 collection.set_tensor_ref(tensor_ref)
             self._save_for_tensor(tensor_name, t, check_before_write=True)
 
+    def save_gradients_from_logs(self, gradients):
+        if gradients is not None:
+            gradient_collection = self.get_collection(CollectionKeys.GRADIENTS)
+            step_collections = self._get_collections_to_save_for_step()
+            collections_to_write = (
+                {gradient_collection} if gradient_collection in step_collections else set()
+            )
+            if gradients and isinstance(gradients[0], tuple) is False:
+                gradients = zip(self.model.trainable_variables, gradients)
+            for v, g in gradients:
+                if isinstance(v, tf.Tensor):
+                    # Tensor.name is meaningless with eager execution
+                    layer_name = str(v.numpy(), "utf-8")
+                elif isinstance(v, tf.Variable):
+                    layer_name = v.name
+                else:
+                    layer_name = v
+                layer_name = layer_name.split(":")[0]
+                export_name = "gradients/" + layer_name + "Grad"
+                if isinstance(g, IndexedSlices):
+                    # This class is a simple wrapper for a pair of Tensor objects
+                    # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
+                    g = g.values
+                self._save_tensor_to_file(export_name, g, collections_to_write)
+
     def save_smdebug_logs(self, logs):
         if logs is None:
             return
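
save_gradients_from_logs() accepts either a bare list of gradients, which it pairs positionally with model.trainable_variables, or a list of (variable, gradient) tuples. A hedged sketch of both call shapes; hook, model, x, and y are assumed to exist, and this is not code from the commit:

    import tensorflow as tf

    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_variables)

    hook.save_gradients_from_logs(grads)  # bare gradient list
    hook.save_gradients_from_logs(list(zip(model.trainable_variables, grads)))  # tuples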
@@ -437,24 +481,10 @@ def save_smdebug_logs(self, logs):
                 )
             # Save Gradients
             elif key == SMDEBUG_GRADIENTS_KEY:
-                gradients = logs[key]
-                if gradients is not None:
-                    for g, v in zip(gradients, self.model.trainable_variables):
-                        layer_name = v.name
-                        if len(layer_name.split(":")) > 1:
-                            layer_name = layer_name.split(":")[0]
-                        export_name = "gradients/" + layer_name + "Grad"
-                        if isinstance(g, IndexedSlices):
-                            # This class is a simple wrapper for a pair of Tensor objects
-                            # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
-                            g = g.values
-                        tensors_to_save.append((export_name, g))
-                    collections_to_write = {self.get_collection(CollectionKeys.GRADIENTS)}
+                self.save_gradients_from_logs(logs[key])
             # Save Intermediate Layers
             elif key == SMDEBUG_LAYER_OUTPUTS_KEY:
-                layer_outputs = logs[key]
-                self.save_layer_outputs(layer_outputs)
-                self.save_layer_inputs(logs[ModelInput.INPUTS], layer_outputs)
+                self._save_layer_values(logs[key])
             # Save Model Inputs
             elif key in ModelInputs:
                 export_name = get_model_input_export_name()
@@ -489,10 +519,9 @@ def _save_metrics(self, batch, logs, force_save=False):
             self._add_metric(metric_name=key)
             self._save_for_tensor(key, logs[key], check_before_write=False)
 
-    def _save_layer_input_and_outputs(self, grad_tape=False):
-        # Iterates over all the saved layers for input and output values
-        if is_tf_version_2x() is False or (grad_tape is False and self.model.run_eagerly is False):
-            # This function only works when run_eagerly is True
+    def _save_layer_input_and_outputs(self):
+        # Run only for GradTape
+        if self.tape is None:
             return
         for layer_name in self.saved_layers:
             # Save Input
@@ -520,7 +549,6 @@ def _save_tensors_post_step(self, batch, logs):
         # weights, metrics
         self._save_metrics(batch, logs)
         self.save_smdebug_logs(logs)
-        self._save_layer_input_and_outputs()
         self._save_custom_tensors_post_step()
 
         if is_tf_version_2x() and tf.executing_eagerly():
@@ -615,6 +643,13 @@ def _on_any_mode_begin(self, mode):
             self.graph = tf.get_default_graph()
         self.set_mode(mode)
 
+        if self.prepared_collections is False and is_tf_version_2_3_x():
+            # Addresses ordering issues in TF 2.3.0;
+            # sets prepared_collections to True here
+            self._prepare_collections()
+            self._increment_step()
+            self.step_incremented_in_on_train_begin = True
+
         # have to clear callable cache if we are not caching per mode
         self.callable_cache.change_mode()
 
@@ -658,7 +693,12 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
         # Write the gradients of the past step if the writer is still available.
         if self.writer is not None or len(self.writer_map):
             self._close_writers()
-        self._increment_step()
+
+        # Addresses callback ordering bug in TF 2.3.0
+        if self.step_incremented_in_on_train_begin is False:
+            self._increment_step()
+        else:
+            self.step_incremented_in_on_train_begin = False
 
         if self.prepared_collections is False:
             # sets prepared_collections to True here
@@ -668,7 +708,6 @@
         if (is_tf_version_2x() and tf.executing_eagerly()) or self._validate_exec_function(
             self._get_exec_function(mode)
         ):
-            self._wrap_model_with_input_output_saver()
             self._prepare_layers(mode)
             self._prepare_tensors_available_post_step()
             self._prepared_tensors[mode] = True
@@ -698,33 +737,23 @@ def on_test_batch_begin(self, batch, logs=None):
     def on_predict_batch_begin(self, batch, logs=None):
         self._on_any_batch_begin(batch, ModeKeys.PREDICT, logs=logs)
 
-    def _save_layer_values(self, layer_outputs, collection, model=None, inputs=None):
-        if model is None:
-            if self.model:
-                model = self.model
-            else:
-                return
-        if layer_outputs is not None:
-            tensors_to_save = []
-            step_collections = self._get_collections_to_save_for_step()
-            collections_to_write = {collection} if collection in step_collections else set()
-            tensor_suffix = "output"
-            if inputs is not None:
-                layer_outputs = [inputs] + layer_outputs
-                tensor_suffix = "input"
-            for o, l in zip(layer_outputs, model.layers):
-                export_name = get_export_name_for_keras(l.name, tensor_suffix)
-                tensors_to_save.append((export_name, o))
-            for t_name, t_value in tensors_to_save:
-                self._save_tensor_to_file(t_name, t_value, collections_to_write)
-
-    def save_layer_outputs(self, layer_outputs, model=None):
-        self._save_layer_values(layer_outputs, self.get_collection(CollectionKeys.LAYERS), model)
-
-    def save_layer_inputs(self, x, layer_outputs, model=None):
-        self._save_layer_values(
-            layer_outputs, self.get_collection(CollectionKeys.LAYERS), model, inputs=x
-        )
+    def _save_layer_values(self, logs):
+        if logs is None:
+            return
+        step_collections = self._get_collections_to_save_for_step()
+        layer_collection = self.get_collection(CollectionKeys.LAYERS)
+        collections_to_write = {layer_collection} if layer_collection in step_collections else set()
+        for layer_name, layer_input, layer_output in logs:
+            # Cast layer_name to str since it can also be of type bytes
+            # when run with mirrored strategy
+            if len(layer_input) == 1:
+                # Layer Inputs are flattened and passed as a list into
+                # the next layer. Unpacking it speeds up the _make_numpy fn.
+                layer_input = layer_input[0]
+            layer_input_tensor_name = get_export_name_for_keras(str(layer_name), "input")
+            self._save_tensor_to_file(layer_input_tensor_name, layer_input, collections_to_write)
+            layer_output_tensor_name = get_export_name_for_keras(str(layer_name), "output")
+            self._save_tensor_to_file(layer_output_tensor_name, layer_output, collections_to_write)
 
     def _write_optimizer_variables(self):
         optimizer_collections = self.collection_manager.get(CollectionKeys.OPTIMIZER_VARIABLES)
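
The rewritten _save_layer_values() no longer zips a flat output list against model.layers; it expects logs to be an iterable of (layer_name, layer_inputs, layer_output) triples, where layer_inputs is a list that gets unpacked when it holds a single element. An illustrative shape for the expected payload; the array values are made up and the direct call to the private method is for demonstration only:

    import numpy as np

    # Each entry: (layer_name, layer_inputs as a list, layer_output)
    logs = [
        ("dense", [np.ones((8, 4))], np.ones((8, 2))),
        (b"dense_1", [np.ones((8, 2))], np.ones((8, 1))),  # bytes names are cast to str
    ]
    hook._save_layer_values(logs)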
@@ -951,7 +980,7 @@ def run(*args, **kwargs):
                 )
 
             self._write_optimizer_variables()
-            self._save_layer_input_and_outputs(grad_tape=True)
+            self._save_layer_input_and_outputs()
             if not ((isinstance(loss, tf.Tensor)) and hasattr(loss, "numpy")):
                 return grads
             self._add_metric(metric_name="loss", metric_value=loss)

smdebug/tensorflow/utils.py

Lines changed: 4 additions & 0 deletions
@@ -384,3 +384,7 @@ def get_keras_mode(mode):
 
 def is_tf_version_2x():
     return version.parse(tf.__version__) >= version.parse("2.0.0")
+
+
+def is_tf_version_2_3_x():
+    return version.parse(tf.__version__) >= version.parse("2.3.0")
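
Note that despite its name, is_tf_version_2_3_x() is a greater-than-or-equal check, so it also returns True on TF 2.4 and later, mirroring how is_tf_version_2x() behaves. A quick illustration using the same packaging.version comparison; the version strings are examples:

    from packaging import version

    # The predicate is ">= 2.3.0", not "== 2.3.*":
    assert version.parse("2.3.0") >= version.parse("2.3.0")
    assert version.parse("2.4.1") >= version.parse("2.3.0")
    assert not (version.parse("2.2.0") >= version.parse("2.3.0"))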

tests/tensorflow2/test_keras.py

Lines changed: 1 addition & 5 deletions
@@ -520,11 +520,7 @@ def test_include_regex(out_dir, tf_eager_mode):
 
     tr = create_trial_fast_refresh(out_dir)
     tnames = tr.tensor_names(collection="custom_coll")
-
-    if tf_eager_mode:
-        assert len(tnames) == (12 if is_tf_2_2() else 8)
-    else:
-        assert len(tnames) == 8
+    assert len(tnames) == 12
     for tname in tnames:
         assert tr.tensor(tname).value(0) is not None
