Skip to content

Commit 1053ee5

Browse files
authored
Save Optimizer Variables With Keras Fit API In Eager Mode (aws#218)
1 parent 46b1797 commit 1053ee5

File tree

6 files changed

+91
-24
lines changed

6 files changed

+91
-24
lines changed

smdebug/tensorflow/base_hook.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,11 @@ def _get_collections_with_tensor(self, tf_tensor_name) -> Set["Collection"]:
399399
# tensors are not matched with collections at preparation time.
400400
# Call core/hook.py's _get_collections_with_tensor() where tensors are
401401
# matched with collections by regex
402-
if self.tape:
402+
if self.tape or (
403+
tf_tensor_name not in self.tensor_to_collections
404+
and is_tf_version_2x()
405+
and tf.executing_eagerly()
406+
):
403407
return super()._get_collections_with_tensor(tf_tensor_name)
404408
return self.tensor_to_collections[tf_tensor_name]
405409

@@ -457,8 +461,8 @@ def set_gradients(self, gradients=None, gradients_and_variables=None):
457461
:param gradients_and_variables: list of tuples [(tf.Tensor/tf.Variable, tf.Tensor/tf.Variable)...]
458462
list of tuples representing gradients and weights
459463
"""
460-
# TF 2.x doesn't provide gradient/optimizer variable names and values by default.
461-
# Skipping set_gradients and set_optimizer_variables for Tf 2.x until there is
464+
# TF 2.x provides only symbolic gradient variables that do not provide access to their values.
465+
# Skipping set_gradients for TF 2.x until there is
462466
# support to pass names and values from TF side.
463467

464468
# From TF 2.2, executing_eagerly_outside_functions() can be used as
@@ -482,18 +486,12 @@ def set_optimizer_variables(self, optimizer_variables):
482486
This method helps find the optimizer variables (such as momentum)
483487
:param optimizer_variables: list of tf.Variables/tf.Tensors/tf.MirroredVariables
484488
"""
485-
# TF 2.x doesn't provide gradient/optimizer variable names and values by default.
486-
# Skipping set_gradients and set_optimizer_variables for Tf 2.x until there is
487-
# support to pass names and values from TF side.
488-
489489
# From TF 2.2, executing_eagerly_outside_functions() can be used as
490490
# ops.executing_eagerly_outside_functions() or tf.compat.v1.executing_eagerly_outside_functions().
491491
# But in TF 2.1, only ops.executing_eagerly_outside_functions() is valid
492-
if is_tf_version_2x() and ops.executing_eagerly_outside_functions():
493-
return
494492
# since this is done one variable at a time for keras, not checking if set already
495493
self.collection_manager.get(CollectionKeys.OPTIMIZER_VARIABLES).add_for_mode(
496-
optimizer_variables, ModeKeys.TRAIN
494+
optimizer_variables, self.mode
497495
)
498496

499497
@staticmethod

smdebug/tensorflow/keras.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tensorflow.python.distribute import values
77

88
# First Party
9+
from smdebug.core.collection import DEFAULT_TF_COLLECTIONS
910
from smdebug.core.modes import ModeKeys
1011
from smdebug.core.utils import match_inc
1112
from smdebug.tensorflow.callable_cache import CallableCache
@@ -317,12 +318,28 @@ def _prepare_layers(self, mode):
317318

318319
def _prepare_non_layer_tensors(self):
319320
# for gradients, optimizer_variables
321+
custom_collections = set()
322+
default_tf_collection = set()
323+
320324
for coll in self.collection_manager.get_collections().values():
321-
for tensor_ref in coll.get_tensors():
325+
if coll.name not in DEFAULT_TF_COLLECTIONS:
326+
custom_collections.add(coll)
327+
else:
328+
default_tf_collection.add(coll)
329+
330+
for default_coll in default_tf_collection:
331+
for tensor_ref in default_coll.get_tensors():
322332
if tensor_ref.name not in self.tensor_to_collections:
323-
self.tensor_to_collections[tensor_ref.name] = {coll}
324-
elif coll not in self.tensor_to_collections[tensor_ref.name]:
325-
self.tensor_to_collections[tensor_ref.name].add(coll)
333+
self.tensor_to_collections[tensor_ref.name] = {default_coll}
334+
elif default_coll not in self.tensor_to_collections[tensor_ref.name]:
335+
self.tensor_to_collections[tensor_ref.name].add(default_coll)
336+
337+
# Add tensor to custom collections
338+
for custom_coll in custom_collections:
339+
if match_inc(tensor_ref.name, custom_coll.include_regex):
340+
custom_coll.add_for_mode(tensor_ref.tf_obj, self.mode)
341+
if custom_coll not in self.tensor_to_collections[tensor_ref.name]:
342+
self.tensor_to_collections[tensor_ref.name].add(custom_coll)
326343

327344
def _prepare_tensors_for_step(self, mode):
328345
self.tensor_refs_to_save_this_step = set()
@@ -550,12 +567,32 @@ def on_test_batch_begin(self, batch, logs=None):
550567
def on_predict_batch_begin(self, batch, logs=None):
551568
self._on_any_batch_begin(batch, ModeKeys.PREDICT, logs=logs)
552569

570+
def _write_optimizer_variables(self):
571+
optimizer_collections = self.collection_manager.get(CollectionKeys.OPTIMIZER_VARIABLES)
572+
collections = self._get_collections_to_save_for_step()
573+
for tensor_ref in optimizer_collections.get_tensors(self.mode):
574+
for coll in collections:
575+
if coll in self.tensor_to_collections[tensor_ref.name]:
576+
tensor = tensor_ref.tf_obj
577+
self._save_for_tensor(
578+
tensor_name=tensor.name,
579+
tensor_value=tensor.value(),
580+
check_before_write=False,
581+
)
582+
553583
def _on_any_batch_end(self, batch, mode, logs=None):
554584
if self._is_not_supported():
555585
return
556586

557587
if not is_tf_version_2x() or (is_tf_version_2x() and not tf.executing_eagerly()):
558588
self._remove_fetches_and_callbacks(mode)
589+
590+
if is_tf_version_2x() and tf.executing_eagerly():
591+
# Need to prepare non layer tensors again since
592+
# some tensors only become available on batch end
593+
self._prepare_non_layer_tensors()
594+
self._write_optimizer_variables()
595+
559596
self._save_tensors_post_step(batch, logs)
560597

561598
if self._prepared_tensors[mode]:

tests/tensorflow/keras/test_keras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def test_include_regex(out_dir):
340340
tr = create_trial_fast_refresh(out_dir)
341341
tnames = tr.tensor_names(collection="custom_coll")
342342

343-
assert len(tnames) == 8
343+
assert len(tnames) == 12
344344
for tname in tnames:
345345
assert tr.tensor(tname).value(0) is not None
346346

tests/tensorflow/keras/test_keras_mirrored.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def test_include_regex(out_dir):
502502
tr = create_trial_fast_refresh(out_dir)
503503
tnames = tr.tensor_names(collection="custom_coll")
504504

505-
assert len(tnames) == 4 + 3 * strategy.num_replicas_in_sync
505+
assert len(tnames) == 4 + 4 + 3 * strategy.num_replicas_in_sync
506506
for tname in tnames:
507507
assert tr.tensor(tname).value(0) is not None
508508

tests/tensorflow2/test_keras.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from smdebug.core.access_layer import has_training_ended
1818
from smdebug.core.collection import CollectionKeys
1919
from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR
20+
from smdebug.core.modes import ModeKeys
2021
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
2122
from smdebug.exceptions import TensorUnavailableForStep
2223
from smdebug.tensorflow import ReductionConfig, SaveConfig
@@ -395,11 +396,21 @@ def test_keras_fit(out_dir, tf_eager_mode, saveall):
395396
# can't save gradients in TF 2.x eager mode
396397
if saveall: # save losses, metrics, weights, biases
397398
if tf_eager_mode:
398-
assert len(trial.tensor_names()) == (7 if is_tf_2_2() else 8)
399+
assert len(trial.tensor_names()) == (12 if is_tf_2_2() else 13)
399400
else:
400401
assert len(trial.tensor_names()) == 21
401402
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
402403
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
404+
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
405+
assert (
406+
len(
407+
trial.tensor_names(
408+
collection=CollectionKeys.OPTIMIZER_VARIABLES, mode=ModeKeys.EVAL
409+
)
410+
)
411+
== 0,
412+
"No Optimizer Variables Should be Saved in EVAL Mode",
413+
)
403414
else: # save the default losses and metrics
404415
assert len(trial.tensor_names()) == (3 if is_tf_2_2() and tf_eager_mode else 4)
405416
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
@@ -540,6 +551,7 @@ def test_include_collections(out_dir, tf_eager_mode):
540551
CollectionKeys.OUTPUTS,
541552
CollectionKeys.METRICS,
542553
CollectionKeys.OPTIMIZER_VARIABLES,
554+
"custom_optimizer_variables",
543555
]
544556
save_config = SaveConfig(save_interval=3)
545557
hook = smd.KerasHook(
@@ -548,16 +560,18 @@ def test_include_collections(out_dir, tf_eager_mode):
548560
include_collections=include_collections,
549561
reduction_config=ReductionConfig(norms=ALLOWED_NORMS, reductions=ALLOWED_REDUCTIONS),
550562
)
563+
hook.get_collection("custom_optimizer_variables").include("Adam")
551564
helper_keras_fit(out_dir, hook=hook, steps=["train", "eval", "predict"], eager=tf_eager_mode)
552565

553566
trial = smd.create_trial(path=out_dir)
554567
# can't save gradients in TF 2.x
555568
if tf_eager_mode:
556-
assert len(trial.tensor_names()) == (7 if is_tf_2_2() else 8)
569+
assert len(trial.tensor_names()) == (12 if is_tf_2_2() else 13)
557570
else:
558571
assert len(trial.tensor_names()) == 18
559572
assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) == 4
560-
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
573+
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
574+
assert len(trial.tensor_names(collection="custom_optimizer_variables")) == 5
561575
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
562576
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
563577
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
@@ -566,6 +580,24 @@ def test_include_collections(out_dir, tf_eager_mode):
566580
)
567581

568582

583+
@pytest.mark.slow
584+
def test_include_only_custom_collection(out_dir, tf_eager_mode):
585+
include_collections = ["custom_optimizer_variables"]
586+
save_config = SaveConfig(save_interval=3)
587+
hook = smd.KerasHook(
588+
out_dir,
589+
save_config=save_config,
590+
include_collections=include_collections,
591+
reduction_config=ReductionConfig(norms=ALLOWED_NORMS, reductions=ALLOWED_REDUCTIONS),
592+
)
593+
hook.get_collection("custom_optimizer_variables").include("Adam")
594+
helper_keras_fit(out_dir, hook=hook, steps=["train", "eval", "predict"], eager=tf_eager_mode)
595+
596+
trial = smd.create_trial(path=out_dir)
597+
assert len(trial.tensor_names()) == (8 if is_tf_2_2() and tf_eager_mode else 9)
598+
assert len(trial.tensor_names(collection="custom_optimizer_variables")) == 5
599+
600+
569601
@pytest.mark.slow
570602
def test_hook_from_json(out_dir, tf_eager_mode, monkeypatch):
571603
monkeypatch.setenv(

tests/tensorflow2/test_keras_mirrored.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,9 @@ def exhaustive_check(trial_dir, include_workers="one", eager=True):
158158
if include_workers == "all":
159159
assert len(tr.workers()) == strategy.num_replicas_in_sync
160160
if eager:
161-
assert len(tr.tensor_names()) == (6 + 1 + 2 if is_tf_2_2() else 6 + 1 + 3)
162-
# 6 weights, 1 loss, 3 metrics for Tf 2.1
163-
# 6 weights, 1 loss, 2 metrics for Tf 2.2
161+
assert len(tr.tensor_names()) == (6 + 1 + 2 + 5 if is_tf_2_2() else 6 + 1 + 3 + 5)
162+
# 6 weights, 1 loss, 3 metrics, 5 optimizer variables for TF 2.1
163+
# 6 weights, 1 loss, 2 metrics, 5 optimizer variables for TF 2.2
164164
else:
165165
assert len(tr.tensor_names()) == (6 + 6 + 1 + 3 + strategy.num_replicas_in_sync * 3 + 5)
166166
else:
@@ -245,8 +245,8 @@ def test_save_all(out_dir, tf_eager_mode):
245245
tr = create_trial_fast_refresh(out_dir)
246246
print(tr.tensor_names())
247247
if tf_eager_mode:
248-
assert len(tr.tensor_names()) == (6 + 2 + 1 if is_tf_2_2() else 6 + 3 + 1)
249-
# weights, metrics, losses
248+
assert len(tr.tensor_names()) == (6 + 2 + 1 + 5 if is_tf_2_2() else 6 + 3 + 1 + 5)
249+
# weights, metrics, losses, optimizer variables
250250
else:
251251
assert (
252252
len(tr.tensor_names())

0 commit comments

Comments
 (0)