47
47
48
48
49
49
class ScalarCache (object ):
50
- def __init__ (self , scalar_name , scalar_val , sm_metric , write_tb , write_event ):
50
+ def __init__ (self , scalar_name , scalar_val , mode , sm_metric , write_tb , write_event ):
51
51
self .name = scalar_name
52
52
self .value = scalar_val
53
+ self .mode = mode
53
54
self .sm_metric = sm_metric
54
55
self .write_tb = write_tb
55
56
self .write_event = write_event
@@ -440,6 +441,10 @@ def _increment_step(self):
440
441
441
442
self .step += 1
442
443
self .mode_steps [self .mode ] += 1
444
+
445
+ # Increment Global step number irrespective of what mode it is
446
+ if self .mode != ModeKeys .GLOBAL :
447
+ self .mode_steps [ModeKeys .GLOBAL ] = self .step
443
448
self ._collections_to_save_for_step = None
444
449
445
450
def _write_state (self ):
@@ -564,12 +569,15 @@ def _write_scalars(self):
564
569
for scalar_obj in self .scalar_cache :
565
570
scalar_name = scalar_obj .name
566
571
scalar_val = scalar_obj .value
572
+ scalar_mode = scalar_obj .mode
567
573
sm_metric = scalar_obj .sm_metric
568
574
write_tb = scalar_obj .write_tb
569
575
write_event = scalar_obj .write_event
570
576
if self .metrics_writer and sm_metric :
571
577
self .metrics_writer .log_metric (
572
- scalar_name , scalar_val , iteration_number = self .mode_steps [self .mode ]
578
+ scalar_name + "_" + scalar_mode .name ,
579
+ scalar_val ,
580
+ iteration_number = self .mode_steps [scalar_mode ],
573
581
)
574
582
if write_tb :
575
583
tb_writer = self ._maybe_get_tb_writer ()
@@ -596,7 +604,7 @@ def save_scalar(self, name, value, sm_metric=False):
596
604
val = self ._make_numpy_array (value )
597
605
if val .size != 1 :
598
606
raise TypeError (f"{ name } has non scalar value of type: { type (value )} " )
599
- scalar_obj = ScalarCache (name , val , sm_metric = True , write_tb = True , write_event = True )
607
+ scalar_obj = ScalarCache (name , val , self . mode , sm_metric , write_tb = True , write_event = True )
600
608
self .scalar_cache .append (scalar_obj )
601
609
602
610
def _write_raw_tensor (self , tensor_name , tensor_value , save_collections , tensor_ref = None ):
@@ -657,7 +665,12 @@ def _save_for_tensor(self, tensor_name, tensor_value, check_before_write=True):
657
665
# Always log loss to Minerva
658
666
tensor_val = np .mean (np_val )
659
667
scalar_obj = ScalarCache (
660
- tensor_name , tensor_val , sm_metric = True , write_tb = False , write_event = False
668
+ tensor_name ,
669
+ tensor_val ,
670
+ self .mode ,
671
+ sm_metric = True ,
672
+ write_tb = False ,
673
+ write_event = False ,
661
674
)
662
675
self .scalar_cache .append (scalar_obj )
663
676
0 commit comments