Skip to content

Commit 749bded

Browse files
authored
TF 2.x: Support for keras to estimator (aws#268)
1 parent 94acc66 commit 749bded

File tree

7 files changed

+124
-10
lines changed

7 files changed

+124
-10
lines changed

docs/sagemaker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Here's a list of frameworks and versions which support this experience.
2727

2828
| Framework | Version |
2929
| --- | --- |
30-
| [TensorFlow](tensorflow.md) | 1.15, 2.1 |
30+
| [TensorFlow](tensorflow.md) | 1.15, 2.1, 2.2 |
3131
| [MXNet](mxnet.md) | 1.6 |
3232
| [PyTorch](pytorch.md) | 1.4, 1.5 |
3333
| [XGBoost](xgboost.md) | >=0.90-2 [As Built-in algorithm](xgboost.md#use-xgboost-as-a-built-in-algorithm)|

docs/tensorflow.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,22 @@
1515
### Versions
1616
- Zero Script Change experience where you need no modifications to your training script is supported in the official [SageMaker Framework Container for TensorFlow 1.15](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html), or the [AWS Deep Learning Container for TensorFlow 1.15](https://aws.amazon.com/machine-learning/containers/).
1717

18-
- This library itself supports the following versions when you use our API which requires a few minimal changes to your training script: TensorFlow 1.14, 1.15, 2.0.1, 2.1.0. Keras 2.3.
18+
- This library itself supports the following versions when you use our API which requires a few minimal changes to your training script: TensorFlow 1.14, 1.15, 2.0+. Keras 2.3.
1919

2020
### Interfaces
21-
- [Estimator](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator)
22-
- [tf.keras](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras)
23-
- [MonitoredSession](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/MonitoredSession?hl=en)
21+
- TF 1.x:
22+
- [Estimator](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator)
23+
- [tf.keras](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras)
24+
- [MonitoredSession](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/MonitoredSession?hl=en)
25+
- TF 2.x:
26+
- [Estimator](https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/estimator)
27+
- [tf.keras](https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras)
28+
2429

2530
### Distributed training
2631
- [MirroredStrategy](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/distribute/MirroredStrategy) or [Contrib MirroredStrategy](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/distribute/MirroredStrategy)
2732

28-
We will very quickly follow up with support for Horovod and Parameter Server based training.
33+
We will very quickly follow up with support for Parameter Server based training.
2934

3035
---
3136

smdebug/tensorflow/collection.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,6 @@ def __init__(self, collections=None, create_default=True):
148148
self.create_collection(n)
149149
if is_tf_version_2x() and tf.executing_eagerly():
150150
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
151-
self.get(CollectionKeys.WEIGHTS).include("^weights/.*/((?!bias).)*$")
152-
self.get(CollectionKeys.LOSSES).include(".*loss.*")
153-
self.get(CollectionKeys.GRADIENTS).include("^gradient")
154151
else:
155152
self.get(CollectionKeys.BIASES).include("bias")
156153

smdebug/tensorflow/keras.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,11 @@ def run(*args, **kwargs):
716716
# at this point we need all collections to be ready
717717
# this may not be the case at creation of hook
718718
# as user's code after hook might add collections
719+
self.collection_manager.get(CollectionKeys.WEIGHTS).include(
720+
"^weights/.*/((?!bias).)*$"
721+
)
722+
self.collection_manager.get(CollectionKeys.LOSSES).include(".*loss.*")
723+
self.collection_manager.get(CollectionKeys.GRADIENTS).include("^gradient")
719724
self._prepare_collections()
720725
self.prepared_collections = True
721726

tests/tensorflow2/test_estimator.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Standard Library
2+
# Third Party
3+
import pytest
4+
import tensorflow.compat.v2 as tf
5+
from tests.zero_code_change.tf_utils import get_estimator, get_input_fns
6+
7+
# First Party
8+
import smdebug.tensorflow as smd
9+
10+
11+
@pytest.mark.parametrize("saveall", [True, False])
12+
def test_estimator(out_dir, tf_eager_mode, saveall):
13+
""" Works as intended. """
14+
if tf_eager_mode is False:
15+
tf.compat.v1.disable_eager_execution()
16+
tf.compat.v1.reset_default_graph()
17+
tf.keras.backend.clear_session()
18+
mnist_classifier = get_estimator()
19+
train_input_fn, eval_input_fn = get_input_fns()
20+
21+
# Train and evaluate
22+
train_steps, eval_steps = 8, 2
23+
hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)
24+
hook.set_mode(mode=smd.modes.TRAIN)
25+
mnist_classifier.train(input_fn=train_input_fn, steps=train_steps, hooks=[hook])
26+
hook.set_mode(mode=smd.modes.EVAL)
27+
mnist_classifier.evaluate(input_fn=eval_input_fn, steps=eval_steps, hooks=[hook])
28+
29+
# Check that hook created and tensors saved
30+
trial = smd.create_trial(path=out_dir)
31+
tnames = trial.tensor_names()
32+
assert len(trial.steps()) > 0
33+
if saveall:
34+
assert len(tnames) >= 301
35+
else:
36+
assert len(tnames) == 1
37+
38+
39+
@pytest.mark.parametrize("saveall", [True, False])
40+
def test_linear_classifier(out_dir, tf_eager_mode, saveall):
41+
""" Works as intended. """
42+
if tf_eager_mode is False:
43+
tf.compat.v1.disable_eager_execution()
44+
tf.compat.v1.reset_default_graph()
45+
tf.keras.backend.clear_session()
46+
train_input_fn, eval_input_fn = get_input_fns()
47+
x_feature = tf.feature_column.numeric_column("x", shape=(28, 28))
48+
estimator = tf.estimator.LinearClassifier(
49+
feature_columns=[x_feature], model_dir="/tmp/mnist_linear_classifier", n_classes=10
50+
)
51+
hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)
52+
estimator.train(input_fn=train_input_fn, steps=10, hooks=[hook])
53+
54+
# Check that hook created and tensors saved
55+
trial = smd.create_trial(path=out_dir)
56+
tnames = trial.tensor_names()
57+
assert len(trial.steps()) > 0
58+
if saveall:
59+
assert len(tnames) >= 224
60+
else:
61+
assert len(tnames) == 2

tests/tensorflow2/test_keras.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# Third Party
1313
import pytest
1414
import tensorflow.compat.v2 as tf
15+
import tensorflow_datasets as tfds
1516
from tests.tensorflow2.utils import is_tf_2_2
1617
from tests.tensorflow.utils import create_trial_fast_refresh
1718

@@ -649,3 +650,47 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode):
649650
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
650651
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
651652
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
653+
654+
655+
def test_keras_to_estimator(out_dir, tf_eager_mode):
656+
if not tf_eager_mode:
657+
tf.compat.v1.disable_eager_execution()
658+
tf.compat.v1.reset_default_graph()
659+
660+
tf.keras.backend.clear_session()
661+
662+
model = tf.keras.models.Sequential(
663+
[
664+
tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
665+
tf.keras.layers.Dropout(0.2),
666+
tf.keras.layers.Dense(1, activation="sigmoid"),
667+
]
668+
)
669+
670+
def input_fn():
671+
split = tfds.Split.TRAIN
672+
dataset = tfds.load("iris", split=split, as_supervised=True)
673+
dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels))
674+
dataset = dataset.batch(32).repeat()
675+
return dataset
676+
677+
model.compile(loss="categorical_crossentropy", optimizer="adam")
678+
model.summary()
679+
680+
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=out_dir)
681+
682+
hook = smd.EstimatorHook(out_dir)
683+
684+
hook.set_mode(smd.modes.TRAIN)
685+
keras_estimator.train(input_fn=input_fn, steps=25, hooks=[hook])
686+
687+
hook.set_mode(smd.modes.EVAL)
688+
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10, hooks=[hook])
689+
690+
from smdebug.trials import create_trial
691+
692+
tr = create_trial(out_dir)
693+
assert len(tr.tensor_names()) == 1
694+
assert len(tr.steps()) == 2
695+
assert len(tr.steps(smd.modes.TRAIN)) == 1
696+
assert len(tr.steps(smd.modes.EVAL)) == 1

tests/zero_code_change/tf_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
import tensorflow.compat.v1 as tf
77
import tensorflow_datasets as tfds
8-
from tensorflow.examples.tutorials.mnist import input_data
98

109
tfds.disable_progress_bar()
1110

@@ -232,6 +231,8 @@ def neural_net(x):
232231

233232

234233
def get_data() -> "tf.contrib.learn.python.learn.datasets.base.Datasets":
234+
from tensorflow.examples.tutorials.mnist import input_data
235+
235236
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
236237
return mnist
237238

0 commit comments

Comments
 (0)