@@ -853,7 +853,7 @@ def add_cli_args():
                          by default""")
     cmdline.add_argument('-b', '--batch_size', default=128, type=int,
                          help="""Size of each minibatch per GPU""")
-    cmdline.add_argument('--num_batches', type=int, default=200,
+    cmdline.add_argument('--num_batches', type=int, default=1000,
                          help="""Number of batches to run.
                          Ignored during eval or if num epochs given""")
     cmdline.add_argument('--num_epochs', type=int,
@@ -867,8 +867,8 @@ def add_cli_args():
                          Pass --clear_log if you want to clear all
                          checkpoints and start a fresh run""")
     cmdline.add_argument('--model_dir', default=None, type=str)
-    cmdline.add_argument('--random_seed', type=bool, default=False)
-    cmdline.add_argument('--clear_log', default=False,
+    cmdline.add_argument('--random_seed', type=str2bool, default=False)
+    cmdline.add_argument('--clear_log', type=str2bool, default=False,
                          help="""Clear the log folder passed
                          so a fresh run can be started""")
     cmdline.add_argument('--log_name', type=str, default='hvd_train.log')
@@ -878,7 +878,7 @@ def add_cli_args():
     cmdline.add_argument('--display_every', default=20, type=int,
                          help="""How often (in iterations) to print out
                          running information.""")
-    cmdline.add_argument('--eval', default=False,
+    cmdline.add_argument('--eval', type=str2bool, default=False,
                          help="""Evaluate the top-1 and top-5 accuracy of
                          the latest checkpointed model. If you want to
                          evaluate using multiple GPUs ensure that all
@@ -889,7 +889,7 @@ def add_cli_args():
     cmdline.add_argument('--eval_interval', type=int,
                          help="""Evaluate accuracy per eval_interval
                          number of epochs""")
-    cmdline.add_argument('--fp16', default=True,
+    cmdline.add_argument('--fp16', type=str2bool, default=True,
                          help="""Train using float16 (half) precision instead
                          of float32.""")
     cmdline.add_argument('--num_gpus', default=1, type=int,
@@ -899,10 +899,10 @@ def add_cli_args():
                          print during evaluation""")
     cmdline.add_argument('--save_checkpoints_steps', type=int, default=1000)
     cmdline.add_argument('--save_summary_steps', type=int, default=0)
-    cmdline.add_argument('--adv_bn_init', default=True,
+    cmdline.add_argument('--adv_bn_init', type=str2bool, default=True,
                          help="""init gamme of the last BN of
                          each ResMod at 0.""")
-    cmdline.add_argument('--adv_conv_init', default=True,
+    cmdline.add_argument('--adv_conv_init', type=str2bool, default=True,
                          help="""init conv with MSRA initializer""")
     cmdline.add_argument('--lr', type=float,
                          help="""Start learning rate""")
@@ -927,7 +927,7 @@ def add_cli_args():
                          (decay by a factor at specified steps)
                          or `poly`(polynomial_decay with degree 2)""")
 
-    cmdline.add_argument('--use_larc', default=False,
+    cmdline.add_argument('--use_larc', type=str2bool, default=False,
                          help="""Use Layer wise Adaptive Rate Control
                          which helps convergence at really
                          large batch sizes""")
@@ -950,7 +950,7 @@ def add_cli_args():
     cmdline.add_argument('--lc_beta', default=0.00001, type=float,
                          help="""Liner Cosine Beta""")
 
-    cmdline.add_argument('--increased_aug', default=False,
+    cmdline.add_argument('--increased_aug', type=str2bool, default=False,
                          help="""Increase augmentations helpful when training
                          with large number of GPUs such as 128 or 256""")
     cmdline.add_argument('--contrast', default=0.6, type=float,
@@ -964,16 +964,16 @@ def add_cli_args():
                          help="""Brightness factor""")
 
     # tornasole arguments
-    cmdline.add_argument('--enable_tornasole', default=False,
+    cmdline.add_argument('--enable_tornasole', type=str2bool, default=False,
                          help="""enable Tornasole""")
     cmdline.add_argument('--tornasole_path',
                          default='tornasole_outputs/default_run',
                          help="""Directory in which to write tornasole data.
                          This can be a local path or
                          S3 path in the form s3://bucket_name/prefix_name""")
-    cmdline.add_argument('--tornasole_save_all', default=False,
+    cmdline.add_argument('--tornasole_save_all', type=str2bool, default=False,
                          help="""save all tensors""")
-    cmdline.add_argument('--tornasole_dryrun', default=False,
+    cmdline.add_argument('--tornasole_dryrun', type=str2bool, default=False,
                          help="""If enabled, do not write data to disk""")
     cmdline.add_argument('--tornasole_exclude', nargs='+', default=[],
                          type=str, action='append',
@@ -985,10 +985,10 @@ def add_cli_args():
                          Tornasole's default collection""")
     cmdline.add_argument('--tornasole_step_interval', default=10, type=int,
                          help="""Save tornasole data every N runs""")
-    cmdline.add_argument('--tornasole_save_weights', default=False)
-    cmdline.add_argument('--tornasole_save_gradients', default=False)
-    cmdline.add_argument('--tornasole_save_inputs', default=False)
-    cmdline.add_argument('--tornasole_save_relu_activations', default=False)
+    cmdline.add_argument('--tornasole_save_weights', type=str2bool, default=False)
+    cmdline.add_argument('--tornasole_save_gradients', type=str2bool, default=False)
+    cmdline.add_argument('--tornasole_save_inputs', type=str2bool, default=False)
+    cmdline.add_argument('--tornasole_save_relu_activations', type=str2bool, default=False)
     cmdline.add_argument('--tornasole_relu_reductions', type=str,
                          help="""A comma separated list of reductions can be
                          passed. If passed, saves relu activations
@@ -1022,18 +1022,17 @@ def get_tornasole_hook(FLAGS):
 
     include_collections = []
 
-    if FLAGS.tornasole_save_weights:
+    if FLAGS.tornasole_save_weights is True:
         include_collections.append('weights')
-    if FLAGS.tornasole_save_gradients:
+    if FLAGS.tornasole_save_gradients is True:
         include_collections.append('gradients')
-    if FLAGS.tornasole_save_relu_activations:
+    if FLAGS.tornasole_save_relu_activations is True:
         include_collections.append('relu_activations')
-    if FLAGS.tornasole_save_inputs:
+    if FLAGS.tornasole_save_inputs is True:
         include_collections.append('inputs')
-    if FLAGS.tornasole_include:
+    if FLAGS.tornasole_include is True:
         ts.get_collection('default').include(FLAGS.tornasole_include)
         include_collections.append('default')
-
     return ts.TornasoleHook(out_dir=FLAGS.tornasole_path,
                             save_config=ts.SaveConfig(
                                 save_interval=FLAGS.tornasole_step_interval),
@@ -1097,18 +1096,10 @@ def main():
 
     logger = logging.getLogger(FLAGS.log_name)
     logger.setLevel(logging.INFO)  # INFO, ERROR
-    # file handler which logs debug messages
-    # console handler
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    # add formatter to the handlers
-    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-    formatter = logging.Formatter('%(message)s')
-    ch.setFormatter(formatter)
-    logger.addHandler(ch)
     if not hvd.rank():
         fh = logging.FileHandler(os.path.join(FLAGS.log_dir, FLAGS.log_name))
         fh.setLevel(logging.DEBUG)
+        formatter = logging.Formatter('%(message)s')
         fh.setFormatter(formatter)
         # add handlers to logger
         logger.addHandler(fh)
@@ -1193,7 +1184,7 @@ def main():
        'model': FLAGS.model,
        'decay_steps': decay_steps,
        'n_classes': 1000,
-       'dtype': tf.float16 if FLAGS.fp16 else tf.float32,
+       'dtype': tf.float16 if FLAGS.fp16 is True else tf.float32,
        'format': 'channels_first',
        'device': '/gpu:0',
        'lr': FLAGS.lr,
@@ -1224,7 +1215,7 @@ def main():
            save_checkpoints_steps=FLAGS.save_checkpoints_steps if do_checkpoint else None,
            keep_checkpoint_max=None))
 
-    if FLAGS.enable_tornasole and hvd.rank() == 0:
+    if FLAGS.enable_tornasole is True and hvd.rank() == 0:
        hook = get_tornasole_hook(FLAGS)
 
    if not FLAGS.eval:
@@ -1238,10 +1229,11 @@ def main():
                num_training_samples,
                FLAGS.display_every,
                logger))
-        if FLAGS.enable_tornasole:
+        if FLAGS.enable_tornasole is True:
            training_hooks.append(hook)
        try:
-            hook.set_mode(ts.modes.TRAIN)
+            if FLAGS.enable_tornasole is True:
+                hook.set_mode(ts.modes.TRAIN)
            start_time = time.time()
            classifier.train(
                input_fn=lambda: make_dataset(
@@ -1278,7 +1270,6 @@ def main():
                if (not FLAGS.eval_interval) or \
                        (i % FLAGS.eval_interval != 0):
                    continue
-                hook.set_mode(ts.modes.EVAL)
                eval_result = classifier.evaluate(
                    input_fn=lambda: make_dataset(
                        eval_filenames,
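Note: the new type=str2bool arguments rely on a str2bool helper defined elsewhere in this file and not shown in the diff. The point of the change is that argparse's type=bool treats any non-empty string (even "False") as True, so the previous defaults could not be overridden from the command line. A minimal sketch of such a converter, with illustrative accepted spellings (the actual helper in the repository may differ), could look like:

import argparse

def str2bool(v):
    # Interpret common textual spellings of a boolean for argparse.
    # (argparse's type=bool would treat any non-empty string as True,
    # which is why the flags above switch to an explicit converter.)
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % v)

With a converter like this as the argument type, the boolean switches can be set explicitly on the command line, e.g. --fp16 false --enable_tornasole true, instead of depending on the truthiness of whatever string argparse receives.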