30
30
get_model_input_export_name ,
31
31
get_model_output_export_name ,
32
32
is_keras_optimizer ,
33
+ is_tf_version_2_3_x ,
33
34
is_tf_version_2x ,
34
35
)
35
36
@@ -71,6 +72,14 @@ def __init__(
71
72
) # stores tensors custom tensors saved by users every step
72
73
self .saved_layers = dict ()
73
74
self .has_registered_model = False
75
+ # supports_tf_logs property was introduced in TF 2.3.0
76
+ # it indicates to the framework that the callback is not
77
+ # limited to reading only numpy logs
78
+ self ._supports_tf_logs = True
79
+ # TF 2.3.0 has a callback ordering bug
80
+ # this flag indicated to the train_batch_begin callback
81
+ # the the step was already incremented in the on_train_begin callback
82
+ self .step_incremented_in_on_train_begin = False
74
83
75
84
def _is_not_supported (self ):
76
85
if self .distribution_strategy is None :
@@ -109,7 +118,8 @@ def register_model(self, model):
109
118
# It attaches a hook to every layer of the model to capture
110
119
# layer values
111
120
self .model = model
112
- self ._wrap_model_with_input_output_saver ()
121
+ if self .tape is not None :
122
+ self ._wrap_model_with_input_output_saver ()
113
123
self .has_registered_model = True
114
124
115
125
def _get_matching_collections (
@@ -348,7 +358,10 @@ def _prepare_tensors_available_post_step(self):
348
358
349
359
# Add tensor to custom collections
350
360
for custom_coll in custom_collections :
351
- if match_inc (tensor_ref .name , custom_coll .include_regex ):
361
+ if (
362
+ match_inc (tensor_ref .name , custom_coll .include_regex )
363
+ and tensor_ref .tf_obj is not None
364
+ ):
352
365
custom_coll .add_for_mode (tensor_ref .tf_obj , self .mode )
353
366
if custom_coll not in self .tensor_to_collections [tensor_ref .name ]:
354
367
self .tensor_to_collections [tensor_ref .name ].add (custom_coll )
@@ -390,6 +403,12 @@ def _save_custom_tensors_post_step(self):
390
403
self ._save_tensor_to_file (tensor_name , tensor_value , collection_names )
391
404
self .custom_tensors_to_save .clear ()
392
405
406
def should_save_layer(self, layer_name):
    """Return True if values for the layer named *layer_name* should be saved.

    Invoked from AWS TensorFlow to decide whether a particular layer's
    values are written out; delegates to the generic tensor/collection
    check against the LAYERS collection.
    """
    return self.should_save_tensor_or_collection(layer_name, CollectionKeys.LAYERS)
411
+
393
412
def _save_tensor_to_file (self , tensor_name , tensor_value , collections ):
394
413
if isinstance (collections , set ) is False :
395
414
collections = {collections }
@@ -418,6 +437,31 @@ def _save_tensor_to_file(self, tensor_name, tensor_value, collections):
418
437
collection .set_tensor_ref (tensor_ref )
419
438
self ._save_for_tensor (tensor_name , t , check_before_write = True )
420
439
440
def save_gradients_from_logs(self, gradients):
    """Write gradient values received through callback logs to file.

    ``gradients`` is either a list of ``(variable, gradient)`` pairs or a
    bare list of gradients; in the latter case pairs are reconstructed by
    zipping with ``self.model.trainable_variables``. Each gradient is saved
    under ``gradients/<layer_name>Grad``, in the GRADIENTS collection when
    that collection is scheduled for this step.
    """
    if gradients is None:
        return
    step_collections = self._get_collections_to_save_for_step()
    gradient_collection = self.get_collection(CollectionKeys.GRADIENTS)
    if gradient_collection in step_collections:
        collections_to_write = {gradient_collection}
    else:
        collections_to_write = set()
    if gradients and not isinstance(gradients[0], tuple):
        gradients = zip(self.model.trainable_variables, gradients)
    for variable, grad in gradients:
        if isinstance(variable, tf.Tensor):
            # Tensor.name is meaningless with eager execution, so the
            # tensor's value carries the (byte-encoded) name instead.
            layer_name = str(variable.numpy(), "utf-8")
        elif isinstance(variable, tf.Variable):
            layer_name = variable.name
        else:
            layer_name = variable
        # Drop the ":0"-style device suffix from the variable name.
        export_name = "gradients/" + layer_name.split(":")[0] + "Grad"
        if isinstance(grad, IndexedSlices):
            # IndexedSlices is a simple wrapper for a pair of Tensor objects.
            # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
            grad = grad.values
        self._save_tensor_to_file(export_name, grad, collections_to_write)
464
+
421
465
def save_smdebug_logs (self , logs ):
422
466
if logs is None :
423
467
return
@@ -437,24 +481,10 @@ def save_smdebug_logs(self, logs):
437
481
)
438
482
# Save Gradients
439
483
elif key == SMDEBUG_GRADIENTS_KEY :
440
- gradients = logs [key ]
441
- if gradients is not None :
442
- for g , v in zip (gradients , self .model .trainable_variables ):
443
- layer_name = v .name
444
- if len (layer_name .split (":" )) > 1 :
445
- layer_name = layer_name .split (":" )[0 ]
446
- export_name = "gradients/" + layer_name + "Grad"
447
- if isinstance (g , IndexedSlices ):
448
- # This class is a simple wrapper for a pair of Tensor objects
449
- # See: https://www.tensorflow.org/api_docs/python/tf/IndexedSlices
450
- g = g .values
451
- tensors_to_save .append ((export_name , g ))
452
- collections_to_write = {self .get_collection (CollectionKeys .GRADIENTS )}
484
+ self .save_gradients_from_logs (logs [key ])
453
485
# Save Intermediate Layers
454
486
elif key == SMDEBUG_LAYER_OUTPUTS_KEY :
455
- layer_outputs = logs [key ]
456
- self .save_layer_outputs (layer_outputs )
457
- self .save_layer_inputs (logs [ModelInput .INPUTS ], layer_outputs )
487
+ self ._save_layer_values (logs [key ])
458
488
# Save Model Inputs
459
489
elif key in ModelInputs :
460
490
export_name = get_model_input_export_name ()
@@ -489,10 +519,9 @@ def _save_metrics(self, batch, logs, force_save=False):
489
519
self ._add_metric (metric_name = key )
490
520
self ._save_for_tensor (key , logs [key ], check_before_write = False )
491
521
492
- def _save_layer_input_and_outputs (self , grad_tape = False ):
493
- # Iterates over all the saved layers for input and output values
494
- if is_tf_version_2x () is False or (grad_tape is False and self .model .run_eagerly is False ):
495
- # This function only works when the run_eagerly is True
522
+ def _save_layer_input_and_outputs (self ):
523
+ # Run only for GradTape
524
+ if self .tape is None :
496
525
return
497
526
for layer_name in self .saved_layers :
498
527
# Save Input
@@ -520,7 +549,6 @@ def _save_tensors_post_step(self, batch, logs):
520
549
# weights, metrics
521
550
self ._save_metrics (batch , logs )
522
551
self .save_smdebug_logs (logs )
523
- self ._save_layer_input_and_outputs ()
524
552
self ._save_custom_tensors_post_step ()
525
553
526
554
if is_tf_version_2x () and tf .executing_eagerly ():
@@ -615,6 +643,13 @@ def _on_any_mode_begin(self, mode):
615
643
self .graph = tf .get_default_graph ()
616
644
self .set_mode (mode )
617
645
646
+ if self .prepared_collections is False and is_tf_version_2_3_x ():
647
+ # Addresses ordering issues in TF 2.3.0
648
+ # sets prepared_collections to True here
649
+ self ._prepare_collections ()
650
+ self ._increment_step ()
651
+ self .step_incremented_in_on_train_begin = True
652
+
618
653
# have to clear callable cache if we are not caching per mode
619
654
self .callable_cache .change_mode ()
620
655
@@ -658,7 +693,12 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
658
693
# Write the gradients of the past step if the writer is still available.
659
694
if self .writer is not None or len (self .writer_map ):
660
695
self ._close_writers ()
661
- self ._increment_step ()
696
+
697
+ # Addresses callback ordering bug in TF 2.3.0
698
+ if self .step_incremented_in_on_train_begin is False :
699
+ self ._increment_step ()
700
+ else :
701
+ self .step_incremented_in_on_train_begin = False
662
702
663
703
if self .prepared_collections is False :
664
704
# sets prepared_collections to True here
@@ -668,7 +708,6 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
668
708
if (is_tf_version_2x () and tf .executing_eagerly ()) or self ._validate_exec_function (
669
709
self ._get_exec_function (mode )
670
710
):
671
- self ._wrap_model_with_input_output_saver ()
672
711
self ._prepare_layers (mode )
673
712
self ._prepare_tensors_available_post_step ()
674
713
self ._prepared_tensors [mode ] = True
@@ -698,33 +737,23 @@ def on_test_batch_begin(self, batch, logs=None):
698
737
def on_predict_batch_begin (self , batch , logs = None ):
699
738
self ._on_any_batch_begin (batch , ModeKeys .PREDICT , logs = logs )
700
739
701
- def _save_layer_values (self , layer_outputs , collection , model = None , inputs = None ):
702
- if model is None :
703
- if self .model :
704
- model = self .model
705
- else :
706
- return
707
- if layer_outputs is not None :
708
- tensors_to_save = []
709
- step_collections = self ._get_collections_to_save_for_step ()
710
- collections_to_write = {collection } if collection in step_collections else set ()
711
- tensor_suffix = "output"
712
- if inputs is not None :
713
- layer_outputs = [inputs ] + layer_outputs
714
- tensor_suffix = "input"
715
- for o , l in zip (layer_outputs , model .layers ):
716
- export_name = get_export_name_for_keras (l .name , tensor_suffix )
717
- tensors_to_save .append ((export_name , o ))
718
- for t_name , t_value in tensors_to_save :
719
- self ._save_tensor_to_file (t_name , t_value , collections_to_write )
720
-
721
- def save_layer_outputs (self , layer_outputs , model = None ):
722
- self ._save_layer_values (layer_outputs , self .get_collection (CollectionKeys .LAYERS ), model )
723
-
724
- def save_layer_inputs (self , x , layer_outputs , model = None ):
725
- self ._save_layer_values (
726
- layer_outputs , self .get_collection (CollectionKeys .LAYERS ), model , inputs = x
727
- )
740
def _save_layer_values(self, logs):
    """Save per-layer input and output values received through callback logs.

    ``logs`` is an iterable of ``(layer_name, layer_input, layer_output)``
    triples; each input/output is written to file under the LAYERS
    collection when that collection is scheduled for this step.
    """
    if logs is None:
        return
    layer_collection = self.get_collection(CollectionKeys.LAYERS)
    if layer_collection in self._get_collections_to_save_for_step():
        collections_to_write = {layer_collection}
    else:
        collections_to_write = set()
    for layer_name, layer_input, layer_output in logs:
        if len(layer_input) == 1:
            # Layer inputs are flattened and passed as a list into the next
            # layer; unwrapping a single-element list speeds up _make_numpy.
            layer_input = layer_input[0]
        # str() cast: layer_name can be bytes when run with mirrored strategy.
        input_name = get_export_name_for_keras(str(layer_name), "input")
        self._save_tensor_to_file(input_name, layer_input, collections_to_write)
        output_name = get_export_name_for_keras(str(layer_name), "output")
        self._save_tensor_to_file(output_name, layer_output, collections_to_write)
728
757
729
758
def _write_optimizer_variables (self ):
730
759
optimizer_collections = self .collection_manager .get (CollectionKeys .OPTIMIZER_VARIABLES )
@@ -951,7 +980,7 @@ def run(*args, **kwargs):
951
980
)
952
981
953
982
self ._write_optimizer_variables ()
954
- self ._save_layer_input_and_outputs (grad_tape = True )
983
+ self ._save_layer_input_and_outputs ()
955
984
if not ((isinstance (loss , tf .Tensor )) and hasattr (loss , "numpy" )):
956
985
return grads
957
986
self ._add_metric (metric_name = "loss" , metric_value = loss )
0 commit comments