Commit 0167e9e

TensorFlow 1.6 documentation updates (aws#124)
* Update docs for TensorFlow 1.6 changes
* Updated TF special hyperparameters docs
1 parent 54c764c commit 0167e9e

1 file changed

+43
-20
lines changed

README.rst

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -752,12 +752,25 @@ Preparing the TensorFlow training script
752752
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
753753
754754
Your TensorFlow training script must be a **Python 2.7** source file. The current supported TensorFlow
755-
versions are **1.6.0 (default)**, **1.5.0**, and **1.4.1**. This training script **must contain** the following functions:
755+
versions are **1.6.0 (default)**, **1.5.0**, and **1.4.1**. The SageMaker TensorFlow docker image
756+
uses this script by calling specifically-named functions from this script.
757+
758+
The training script **must contain** the following:
759+
760+
- Exactly one of the following:
761+
762+
- ``model_fn``: defines the model that will be trained.
763+
- ``keras_model_fn``: defines the ``tf.keras`` model that will be trained.
764+
- ``estimator_fn``: defines the ``tf.estimator.Estimator`` that will train the model.
756765
757-
- ``model_fn``: defines the model that will be trained.
758766
- ``train_input_fn``: preprocess and load training data.
759767
- ``eval_input_fn``: preprocess and load evaluation data.
760-
- ``serving_input_fn``: defines the features to be passed to the model during prediction.
768+
769+
In addition, it may optionally contain:
770+
771+
- ``serving_input_fn``: Defines the features to be passed to the model during prediction. **Important:**
772+
this function is used only during training, but is required to deploy the model resulting from training
773+
in a SageMaker endpoint.
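
The rules above (exactly one model-defining function, two required input functions, one optional ``serving_input_fn``) can be sketched as a small pre-flight check. This is only an illustration: ``check_training_script`` is a hypothetical helper and is not part of the SageMaker SDK or the TensorFlow container, which may validate scripts differently.

```python
# Hypothetical helper, not part of the SageMaker SDK or container.
MODEL_FNS = ("model_fn", "keras_model_fn", "estimator_fn")
REQUIRED_FNS = ("train_input_fn", "eval_input_fn")


def check_training_script(namespace):
    """Check a training script's namespace against the documented rules.

    `namespace` is any mapping of names to objects, for example
    vars(module) for an imported training script.
    Returns a tuple (errors, warnings), each a list of strings.
    """
    errors = []
    warnings = []

    # Exactly one of model_fn, keras_model_fn, estimator_fn must be defined.
    defined = [name for name in MODEL_FNS if callable(namespace.get(name))]
    if len(defined) != 1:
        errors.append(
            "expected exactly one of %s, found %d"
            % (", ".join(MODEL_FNS), len(defined))
        )

    # Both input functions are required.
    for name in REQUIRED_FNS:
        if not callable(namespace.get(name)):
            errors.append("missing required function %s" % name)

    # serving_input_fn is optional for training but needed for deployment.
    if not callable(namespace.get("serving_input_fn")):
        warnings.append(
            "no serving_input_fn: the trained model cannot be deployed "
            "to a SageMaker endpoint"
        )

    return errors, warnings
```

A script that defines both ``model_fn`` and ``estimator_fn`` would produce one error here, while a script that only omits ``serving_input_fn`` would produce a warning and no errors.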
761774
762775
Creating a ``model_fn``
763776
^^^^^^^^^^^^^^^^^^^^^^^
@@ -793,6 +806,8 @@ The ``model_fn`` must accept four positional arguments:
793806
your TensorFlow training script. You can use this to pass hyperparameters to your
794807
training script.
795808
809+
The ``model_fn`` must return a ``tf.estimator.EstimatorSpec``.
810+
796811
Example of a complete ``model_fn``
797812
''''''''''''''''''''''''''''''''''
798813
@@ -875,9 +890,9 @@ The basic skeleton for the ``train_input_fn`` looks like this:
875890
# Logic to do the following:
876891
# 1. Reads the **training** dataset files located in training_dir
877892
# 2. Preprocess the dataset
878-
# 3. Return 1) a mapping of feature columns to Tensors with
893+
# 3. Return 1) a dict of feature names to Tensors with
879894
# the corresponding feature data, and 2) a Tensor containing labels
880-
return feature_cols, labels
895+
return features, labels
881896
882897
An ``eval_input_fn`` follows the same format:
883898
@@ -887,9 +902,12 @@ An ``eval_input_fn`` follows the same format:
887902
# Logic to do the following:
888903
# 1. Reads the **evaluation** dataset files located in training_dir
889904
# 2. Preprocess the dataset
890-
# 3. Return 1) a mapping of feature columns to Tensors with
905+
# 3. Return 1) a dict of feature names to Tensors with
891906
# the corresponding feature data, and 2) a Tensor containing labels
892-
return feature_cols, labels
907+
return features, labels
908+
909+
**Note:** For TensorFlow 1.4 and 1.5, ``train_input_fn`` and ``eval_input_fn`` may also return a no-argument
910+
function which returns the tuple ``features, labels``. This is no longer supported for TensorFlow 1.6 and up.
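
The difference between the two return styles in the note can be sketched in plain Python. This is an illustration of the documented contract only: ``resolve_input_fn_result`` is a hypothetical helper, plain lists stand in for Tensors, and raising ``ValueError`` is an assumption, since the source does not say how the container reacts to an unsupported return value.

```python
def resolve_input_fn_result(result, framework_version):
    """Normalize what an input function returned into (features, labels).

    Hypothetical helper for illustration. `framework_version` is a string
    such as "1.5.0" or "1.6.0".
    """
    major_minor = tuple(int(part) for part in framework_version.split(".")[:2])

    if callable(result):
        # Returning a no-argument function is documented for 1.4 and 1.5 only.
        if major_minor >= (1, 6):
            # Assumption: the real failure mode is not described in the source.
            raise ValueError(
                "returning a function is not supported for TensorFlow %s"
                % framework_version
            )
        result = result()

    features, labels = result
    return features, labels


# Style that works for every supported version: return the tuple directly.
def train_input_fn(training_dir, hyperparameters):
    features = {"x": [1.0, 2.0]}  # stand-in for a dict of feature Tensors
    labels = [0, 1]               # stand-in for a label Tensor
    return features, labels
```

Returning the tuple directly, as ``train_input_fn`` does here, is the only style the note describes as supported across 1.4, 1.5, and 1.6.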
893911
894912
Example of a complete ``train_input_fn`` and ``eval_input_fn``
895913
''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
@@ -922,14 +940,9 @@ More details on how to create input functions can be find in `Building Input Fun
922940
Creating a ``serving_input_fn``
923941
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
924942
925-
``serving_input_fn`` is used to define the shapes and types of the inputs
926-
the model accepts when the model is exported for Tensorflow Serving. ``serving_input_fn`` is called
927-
at the end of model training and is not called during inference. (If you'd like to preprocess inference data,
928-
please see ``input_fn``). This function has the following purposes:
943+
``serving_input_fn`` is used to define the shapes and types of the inputs the model accepts when the model is exported for TensorFlow Serving. It is optional, but required for deploying the trained model to a SageMaker endpoint.
929944
930-
- To add placeholders to the graph that the serving system will feed with inference requests.
931-
- To add any additional ops needed to convert data from the input format into the feature Tensors
932-
expected by the model.
945+
``serving_input_fn`` is called at the end of model training and is **not** called during inference. (If you'd like to preprocess inference data, please see **Overriding input preprocessing with an input_fn**).
933946
934947
The basic skeleton for the ``serving_input_fn`` looks like this:
935948
@@ -939,8 +952,10 @@ The basic skeleton for the ``serving_input_fn`` looks like this:
939952
# Logic to do the following:
940953
# 1. Defines placeholders that TensorFlow serving will feed with inference requests
941954
# 2. Preprocess input data
942-
# 3. Returns a tf.estimator.export.ServingInputReceiver object, which packages the placeholders
943-
and the resulting feature Tensors together.
955+
# 3. Returns a tf.estimator.export.ServingInputReceiver or tf.estimator.export.TensorServingInputReceiver,
956+
# which packages the placeholders and the resulting feature Tensors together.
957+
958+
**Note:** For TensorFlow 1.4 and 1.5, ``serving_input_fn`` may also return a no-argument function which returns a ``tf.estimator.export.ServingInputReceiver`` or ``tf.estimator.export.TensorServingInputReceiver``. This is no longer supported for TensorFlow 1.6 and up.
944959
945960
Example of a complete ``serving_input_fn``
946961
''''''''''''''''''''''''''''''''''''''''''
@@ -1137,15 +1152,23 @@ These hyperparameters are used by TensorFlow to fine tune the training.
11371152
You need to add them inside the hyperparameters dictionary in the
11381153
``TensorFlow`` estimator constructor.
11391154
1155+
**All versions**
1156+
11401157
- ``save_summary_steps (int)`` Save summaries every this many steps.
11411158
- ``save_checkpoints_secs (int)`` Save checkpoints every this many seconds. Cannot be specified with ``save_checkpoints_steps``.
11421159
- ``save_checkpoints_steps (int)`` Save checkpoints every this many steps. Cannot be specified with ``save_checkpoints_secs``.
11431160
- ``keep_checkpoint_max (int)`` The maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent checkpoint files are kept.)
11441161
- ``keep_checkpoint_every_n_hours (int)`` Number of hours between each checkpoint to be saved. The default value of 10,000 hours effectively disables the feature.
11451162
- ``log_step_count_steps (int)`` The frequency, in number of global steps, that the global step/sec will be logged during training.
1163+
1164+
**TensorFlow 1.6 and up**
1165+
1166+
- ``start_delay_secs (int)`` See docs for this parameter in `tf.estimator.EvalSpec <https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec>`_.
1167+
- ``throttle_secs (int)`` See docs for this parameter in `tf.estimator.EvalSpec <https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec>`_.
1168+
1169+
**TensorFlow 1.4 and 1.5**
1170+
11461171
- ``eval_metrics (dict)`` ``dict`` of string, metric function. If `None`, default set is used. This should be ``None`` if the ``estimator`` is `tf.estimator.Estimator <https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator>`_. If metrics are provided they will be *appended* to the default set.
1147-
- ``train_monitors (list)`` A list of monitors to pass during training.
1148-
- ``eval_hooks (list)`` A list of `SessionRunHook` hooks to pass during evaluation.
11491172
- ``eval_delay_secs (int)`` Start evaluating after waiting for this many seconds.
11501173
- ``continuous_eval_throttle_secs (int)`` Do not re-evaluate unless the last evaluation was started at least this many seconds ago.
11511174
- ``min_eval_frequency (int)`` The minimum number of steps between evaluations. Of course, evaluation does not occur if no new snapshot is available, hence, this is the minimum. If 0, the evaluation will only happen after training. If None, defaults to 1000.
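
The version grouping above can be sketched as a check on a hyperparameters dictionary before it is passed to the ``TensorFlow`` estimator constructor. This is a minimal sketch under stated assumptions: ``check_special_hyperparameters`` is a hypothetical helper, and the name sets cover only the hyperparameters visible in this diff, so names outside them are ignored rather than reported.

```python
# Names taken from the lists above; the real lists may contain more entries.
TF16_AND_UP = {"start_delay_secs", "throttle_secs"}
TF14_15_ONLY = {
    "eval_metrics",
    "eval_delay_secs",
    "continuous_eval_throttle_secs",
    "min_eval_frequency",
}


def check_special_hyperparameters(hyperparameters, framework_version):
    """Return a list of problems with the special hyperparameters.

    Hypothetical helper for illustration. `framework_version` is a string
    such as "1.5.0" or "1.6.0".
    """
    major_minor = tuple(int(part) for part in framework_version.split(".")[:2])
    problems = []

    # The two checkpoint intervals are documented as mutually exclusive.
    if ("save_checkpoints_secs" in hyperparameters
            and "save_checkpoints_steps" in hyperparameters):
        problems.append(
            "save_checkpoints_secs and save_checkpoints_steps "
            "cannot both be specified"
        )

    # Report names that belong to the other version group.
    wrong_group = TF14_15_ONLY if major_minor >= (1, 6) else TF16_AND_UP
    for name in sorted(wrong_group & set(hyperparameters)):
        problems.append(
            "%s does not apply to TensorFlow %s" % (name, framework_version)
        )

    return problems
```

For example, passing ``min_eval_frequency`` together with version ``1.6.0`` would be reported, while ``throttle_secs`` with the same version would not.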
@@ -1398,7 +1421,7 @@ This process looks like this:
13981421
13991422
The common functionality can be extended by the addition of the following two functions to your training script:
14001423
1401-
Overriding input precessing with an ``input_fn``
1424+
Overriding input preprocessing with an ``input_fn``
14021425
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
14031426
14041427
An example of ``input_fn`` for the content-type "application/python-pickle" can be seen below:
@@ -1417,7 +1440,7 @@ An example of ``input_fn`` for the content-type "application/python-pickle" can
14171440
# if the content type is not supported.
14181441
pass
14191442
1420-
Overriding output precessing with an ``output_fn``
1443+
Overriding output postprocessing with an ``output_fn``
14211444
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
14221445
14231446
An example of ``output_fn`` for the accept type "application/python-pickle" can be seen below:
