Skip to content

Commit ba70ca6

Browse files
yifeimlaurenyu
authored andcommitted
SM_HPS is worth mentioning (#522)
Just found a SM_HPS environmental variable that dumps the original hyperparameters as a json string. This helps preserve types and is definitely worth mentioning! Btw, `type=json.loads` seems to be the correct grammar (and `type=dict` causes errors). See example https://stackoverflow.com/questions/7625786/type-dict-in-argparse-add-argument
1 parent f32a441 commit ba70ca6

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/sagemaker/mxnet/README.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ The training script is very similar to a training script you might run outside o
5555
These artifacts are compressed and uploaded to S3 to an S3 bucket with the same prefix as the model artifacts.
5656
* ``SM_CHANNEL_XXXX``: A string that represents the path to the directory that contains the input data for the specified channel.
5757
For example, if you specify two input channels in the MXNet estimator's ``fit`` call, named 'train' and 'test', the environment variables ``SM_CHANNEL_TRAIN`` and ``SM_CHANNEL_TEST`` are set.
58+
* ``SM_HPS``: A json dump of the hyperparameters preserving json types (boolean, integer, etc.)
5859

5960
For the exhaustive list of available environment variables, see the `SageMaker Containers documentation <https://github.com/aws/sagemaker-containers#list-of-provided-environment-variables-by-sagemaker-containers>`__.
6061

@@ -66,6 +67,7 @@ For example, a training script might start with the following:
6667
6768
import argparse
6869
import os
70+
import json
6971
7072
if __name__ =='__main__':
7173
@@ -76,6 +78,9 @@ For example, a training script might start with the following:
7678
parser.add_argument('--batch-size', type=int, default=100)
7779
parser.add_argument('--learning-rate', type=float, default=0.1)
7880
81+
# an alternative way to load hyperparameters via SM_HPS environment variable.
82+
parser.add_argument('--sm-hps', type=json.loads, default=os.environ['SM_HPS'])
83+
7984
# input data and model directories
8085
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
8186
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])

0 commit comments

Comments
 (0)