Skip to content

Commit d20a1d3

Browse files
authored
Fix example (aws#191)
1 parent 0a0e0e5 commit d20a1d3

File tree

2 files changed

+38
-47
lines changed

2 files changed

+38
-47
lines changed

docs/tensorflow/examples/resnet50.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,16 +97,16 @@ so you can run them immediately without having to setup S3 permissions.
9797
### Example commands
9898
#### Saving weights and gradients with Tornasole
9999
```
100-
python train_imagenet_resnet_hvd.py --clear_log --enable_tornasole \
101-
--tornasole_save_weights --tornasole_save_gradients \
100+
python train_imagenet_resnet_hvd.py --clear_log True --enable_tornasole True \
101+
--tornasole_save_weights True --tornasole_save_gradients True \
102102
--tornasole_step_interval 10 \
103103
--tornasole_path ~/ts_outputs/default
104104
```
105105
#### Simulating gradients which 'vanish'
106106
We simulate vanishing gradients (gradients that become extremely small) by initializing the weights with a small constant.
107107
```
108-
python train_imagenet_resnet_hvd.py --clear_log --enable_tornasole \
109-
--tornasole_save_weights --tornasole_save_gradients \
108+
python train_imagenet_resnet_hvd.py --clear_log True --enable_tornasole True \
109+
--tornasole_save_weights True --tornasole_save_gradients True \
110110
--tornasole_step_interval 10 \
111111
--constant_initializer 0.01 \
112112
--tornasole_path ~/ts_outputs/vanishing
@@ -119,15 +119,15 @@ python -m tornasole.rules.rule_invoker --trial-dir ~/ts_outputs/vanishing --rule
119119
```
120120
#### Saving activations of RELU layers in full
121121
```
122-
python train_imagenet_resnet_hvd.py --clear_log --enable_tornasole \
123-
--tornasole_save_relu_activations \
122+
python train_imagenet_resnet_hvd.py --clear_log True --enable_tornasole True \
123+
--tornasole_save_relu_activations True \
124124
--tornasole_step_interval 10 \
125125
--tornasole_path ~/ts_outputs/full_relu_activations
126126
```
127127
#### Saving activations of RELU layers as reductions
128128
```
129-
python train_imagenet_resnet_hvd.py --clear_log --enable_tornasole \
130-
--tornasole_save_relu_activations \
129+
python train_imagenet_resnet_hvd.py --clear_log True --enable_tornasole True \
130+
--tornasole_save_relu_activations True \
131131
--tornasole_relu_reductions min max mean variance \
132132
--tornasole_relu_reductions_abs mean variance \
133133
--tornasole_step_interval 10 \
@@ -137,8 +137,8 @@ python train_imagenet_resnet_hvd.py --clear_log --enable_tornasole \
137137
If you want to compute and track the ratio of weights and updates,
138138
you can do that by saving weights every step as follows:
139139
```
140-
python train_imagenet_resnet_hvd.py --clear_log --enable_tornasole \
141-
--tornasole_save_weights \
140+
python train_imagenet_resnet_hvd.py --clear_log True --enable_tornasole True \
141+
--tornasole_save_weights True \
142142
--tornasole_step_interval 1 \
143143
--tornasole_path ~/ts_outputs/weights
144144
```
@@ -160,7 +160,7 @@ python -m tornasole.rules.rule_invoker --trial-dir ~/ts_outputs/weights --rule-n
160160

161161
#### Running with tornasole disabled
162162
```
163-
python train_imagenet_resnet_hvd.py --clear_log
163+
python train_imagenet_resnet_hvd.py --clear_log True
164164
```
165165
### More
166166
Please refer to the [Tornasole TensorFlow page](../README.md) and the various flags in the script to customize the behavior further.

examples/tensorflow/scripts/train_imagenet_resnet_hvd.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ def add_cli_args():
853853
by default""")
854854
cmdline.add_argument('-b', '--batch_size', default=128, type=int,
855855
help="""Size of each minibatch per GPU""")
856-
cmdline.add_argument('--num_batches', type=int, default=200,
856+
cmdline.add_argument('--num_batches', type=int, default=1000,
857857
help="""Number of batches to run.
858858
Ignored during eval or if num epochs given""")
859859
cmdline.add_argument('--num_epochs', type=int,
@@ -867,8 +867,8 @@ def add_cli_args():
867867
Pass --clear_log if you want to clear all
868868
checkpoints and start a fresh run""")
869869
cmdline.add_argument('--model_dir', default=None, type=str)
870-
cmdline.add_argument('--random_seed', type=bool, default=False)
871-
cmdline.add_argument('--clear_log', default=False,
870+
cmdline.add_argument('--random_seed', type=str2bool, default=False)
871+
cmdline.add_argument('--clear_log', type=str2bool, default=False,
872872
help="""Clear the log folder passed
873873
so a fresh run can be started""")
874874
cmdline.add_argument('--log_name', type=str, default='hvd_train.log')
@@ -878,7 +878,7 @@ def add_cli_args():
878878
cmdline.add_argument('--display_every', default=20, type=int,
879879
help="""How often (in iterations) to print out
880880
running information.""")
881-
cmdline.add_argument('--eval', default=False,
881+
cmdline.add_argument('--eval', type=str2bool, default=False,
882882
help="""Evaluate the top-1 and top-5 accuracy of
883883
the latest checkpointed model. If you want to
884884
evaluate using multiple GPUs ensure that all
@@ -889,7 +889,7 @@ def add_cli_args():
889889
cmdline.add_argument('--eval_interval', type=int,
890890
help="""Evaluate accuracy per eval_interval
891891
number of epochs""")
892-
cmdline.add_argument('--fp16', default=True,
892+
cmdline.add_argument('--fp16', type=str2bool, default=True,
893893
help="""Train using float16 (half) precision instead
894894
of float32.""")
895895
cmdline.add_argument('--num_gpus', default=1, type=int,
@@ -899,10 +899,10 @@ def add_cli_args():
899899
print during evaluation""")
900900
cmdline.add_argument('--save_checkpoints_steps', type=int, default=1000)
901901
cmdline.add_argument('--save_summary_steps', type=int, default=0)
902-
cmdline.add_argument('--adv_bn_init', default=True,
902+
cmdline.add_argument('--adv_bn_init', type=str2bool, default=True,
903903
help="""init gamme of the last BN of
904904
each ResMod at 0.""")
905-
cmdline.add_argument('--adv_conv_init', default=True,
905+
cmdline.add_argument('--adv_conv_init', type=str2bool, default=True,
906906
help="""init conv with MSRA initializer""")
907907
cmdline.add_argument('--lr', type=float,
908908
help="""Start learning rate""")
@@ -927,7 +927,7 @@ def add_cli_args():
927927
(decay by a factor at specified steps)
928928
or `poly`(polynomial_decay with degree 2)""")
929929

930-
cmdline.add_argument('--use_larc', default=False,
930+
cmdline.add_argument('--use_larc', type=str2bool, default=False,
931931
help="""Use Layer wise Adaptive Rate Control
932932
which helps convergence at really
933933
large batch sizes""")
@@ -950,7 +950,7 @@ def add_cli_args():
950950
cmdline.add_argument('--lc_beta', default=0.00001, type=float,
951951
help="""Liner Cosine Beta""")
952952

953-
cmdline.add_argument('--increased_aug', default=False,
953+
cmdline.add_argument('--increased_aug', type=str2bool, default=False,
954954
help="""Increase augmentations helpful when training
955955
with large number of GPUs such as 128 or 256""")
956956
cmdline.add_argument('--contrast', default=0.6, type=float,
@@ -964,16 +964,16 @@ def add_cli_args():
964964
help="""Brightness factor""")
965965

966966
# tornasole arguments
967-
cmdline.add_argument('--enable_tornasole', default=False,
967+
cmdline.add_argument('--enable_tornasole', type=str2bool, default=False,
968968
help="""enable Tornasole""")
969969
cmdline.add_argument('--tornasole_path',
970970
default='tornasole_outputs/default_run',
971971
help="""Directory in which to write tornasole data.
972972
This can be a local path or
973973
S3 path in the form s3://bucket_name/prefix_name""")
974-
cmdline.add_argument('--tornasole_save_all', default=False,
974+
cmdline.add_argument('--tornasole_save_all', type=str2bool, default=False,
975975
help="""save all tensors""")
976-
cmdline.add_argument('--tornasole_dryrun', default=False,
976+
cmdline.add_argument('--tornasole_dryrun', type=str2bool, default=False,
977977
help="""If enabled, do not write data to disk""")
978978
cmdline.add_argument('--tornasole_exclude', nargs='+', default=[],
979979
type=str, action='append',
@@ -985,10 +985,10 @@ def add_cli_args():
985985
Tornasole's default collection""")
986986
cmdline.add_argument('--tornasole_step_interval', default=10, type=int,
987987
help="""Save tornasole data every N runs""")
988-
cmdline.add_argument('--tornasole_save_weights', default=False)
989-
cmdline.add_argument('--tornasole_save_gradients', default=False)
990-
cmdline.add_argument('--tornasole_save_inputs', default=False)
991-
cmdline.add_argument('--tornasole_save_relu_activations', default=False)
988+
cmdline.add_argument('--tornasole_save_weights', type=str2bool, default=False)
989+
cmdline.add_argument('--tornasole_save_gradients', type=str2bool, default=False)
990+
cmdline.add_argument('--tornasole_save_inputs', type=str2bool, default=False)
991+
cmdline.add_argument('--tornasole_save_relu_activations', type=str2bool, default=False)
992992
cmdline.add_argument('--tornasole_relu_reductions', type=str,
993993
help="""A comma separated list of reductions can be
994994
passed. If passed, saves relu activations
@@ -1022,18 +1022,17 @@ def get_tornasole_hook(FLAGS):
10221022

10231023
include_collections = []
10241024

1025-
if FLAGS.tornasole_save_weights:
1025+
if FLAGS.tornasole_save_weights is True:
10261026
include_collections.append('weights')
1027-
if FLAGS.tornasole_save_gradients:
1027+
if FLAGS.tornasole_save_gradients is True:
10281028
include_collections.append('gradients')
1029-
if FLAGS.tornasole_save_relu_activations:
1029+
if FLAGS.tornasole_save_relu_activations is True:
10301030
include_collections.append('relu_activations')
1031-
if FLAGS.tornasole_save_inputs:
1031+
if FLAGS.tornasole_save_inputs is True:
10321032
include_collections.append('inputs')
1033-
if FLAGS.tornasole_include:
1033+
if FLAGS.tornasole_include is True:
10341034
ts.get_collection('default').include(FLAGS.tornasole_include)
10351035
include_collections.append('default')
1036-
10371036
return ts.TornasoleHook(out_dir=FLAGS.tornasole_path,
10381037
save_config=ts.SaveConfig(
10391038
save_interval=FLAGS.tornasole_step_interval),
@@ -1097,18 +1096,10 @@ def main():
10971096

10981097
logger = logging.getLogger(FLAGS.log_name)
10991098
logger.setLevel(logging.INFO) # INFO, ERROR
1100-
# file handler which logs debug messages
1101-
# console handler
1102-
ch = logging.StreamHandler()
1103-
ch.setLevel(logging.INFO)
1104-
# add formatter to the handlers
1105-
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
1106-
formatter = logging.Formatter('%(message)s')
1107-
ch.setFormatter(formatter)
1108-
logger.addHandler(ch)
11091099
if not hvd.rank():
11101100
fh = logging.FileHandler(os.path.join(FLAGS.log_dir, FLAGS.log_name))
11111101
fh.setLevel(logging.DEBUG)
1102+
formatter = logging.Formatter('%(message)s')
11121103
fh.setFormatter(formatter)
11131104
# add handlers to logger
11141105
logger.addHandler(fh)
@@ -1193,7 +1184,7 @@ def main():
11931184
'model': FLAGS.model,
11941185
'decay_steps': decay_steps,
11951186
'n_classes': 1000,
1196-
'dtype': tf.float16 if FLAGS.fp16 else tf.float32,
1187+
'dtype': tf.float16 if FLAGS.fp16 is True else tf.float32,
11971188
'format': 'channels_first',
11981189
'device': '/gpu:0',
11991190
'lr': FLAGS.lr,
@@ -1224,7 +1215,7 @@ def main():
12241215
save_checkpoints_steps=FLAGS.save_checkpoints_steps if do_checkpoint else None,
12251216
keep_checkpoint_max=None))
12261217

1227-
if FLAGS.enable_tornasole and hvd.rank() == 0:
1218+
if FLAGS.enable_tornasole is True and hvd.rank() == 0:
12281219
hook = get_tornasole_hook(FLAGS)
12291220

12301221
if not FLAGS.eval:
@@ -1238,10 +1229,11 @@ def main():
12381229
num_training_samples,
12391230
FLAGS.display_every,
12401231
logger))
1241-
if FLAGS.enable_tornasole:
1232+
if FLAGS.enable_tornasole is True:
12421233
training_hooks.append(hook)
12431234
try:
1244-
hook.set_mode(ts.modes.TRAIN)
1235+
if FLAGS.enable_tornasole is True:
1236+
hook.set_mode(ts.modes.TRAIN)
12451237
start_time = time.time()
12461238
classifier.train(
12471239
input_fn=lambda: make_dataset(
@@ -1278,7 +1270,6 @@ def main():
12781270
if (not FLAGS.eval_interval) or \
12791271
(i % FLAGS.eval_interval != 0):
12801272
continue
1281-
hook.set_mode(ts.modes.EVAL)
12821273
eval_result = classifier.evaluate(
12831274
input_fn=lambda: make_dataset(
12841275
eval_filenames,

0 commit comments

Comments
 (0)