Skip to content

Commit 1f0b16d

Browse files
vandanavk authored and jarednielsen committed
Save scalar (aws#352)
* Write metrics to a file * Write contents to metric file * Enable the code to save_scalar * cache save_scalar before prepare_collections * Modify Minerva log file format * Test save_scalar * Remove TF save_scalar. Modify tests * Update test * Move up an assert * Remove xgboost changes for now * Fix CodeBuild * Change function names in test * Minerva file format change, Eureka SDK integration * Write loss to Minerva * Added some comments * Log scalars before closing write * Address review comments, add searchable scalars * Remove redundant code * Minor changes * Enable tensorboard * Add test for TF searchable scalar * Add some comments to the test * Fix CodeBuild * Use wrap_optimizer * Fix CodeBuild * Move metrics file writer out of tfevents * Keras TF save_scalar * Fix regression with Eureka integration * Keras TF save_scalar * Fix regression with Eureka integration * Flush out scalars before closing the file * Fix build error * close metrics writer after other writers * Address review comments * Correct path in the test * Correct file path * Fix regression * combine initialize_writer
1 parent 340fc57 commit 1f0b16d

File tree

14 files changed

+482
-22
lines changed

14 files changed

+482
-22
lines changed

smdebug/core/collection.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class CollectionKeys:
2727
LOSSES = "losses"
2828
BIASES = "biases"
2929
SCALARS = "scalars"
30+
SEARCHABLE_SCALARS = "searchable_scalars"
3031

3132
OPTIMIZER_VARIABLES = "optimizer_variables"
3233
TENSORFLOW_SUMMARIES = "tensorflow_summaries"
@@ -45,7 +46,12 @@ class CollectionKeys:
4546
# so we don't create summaries or reductions of these
4647
SUMMARIES_COLLECTIONS = {CollectionKeys.TENSORFLOW_SUMMARIES}
4748

48-
SCALAR_COLLECTIONS = {CollectionKeys.LOSSES, CollectionKeys.METRICS, CollectionKeys.SCALARS}
49+
SCALAR_COLLECTIONS = {
50+
CollectionKeys.LOSSES,
51+
CollectionKeys.METRICS,
52+
CollectionKeys.SCALARS,
53+
CollectionKeys.SEARCHABLE_SCALARS,
54+
}
4955

5056
# used by pt, mx, keras
5157
NON_REDUCTION_COLLECTIONS = SCALAR_COLLECTIONS.union(SUMMARIES_COLLECTIONS)

smdebug/core/config_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
CONFIG_INCLUDE_REGEX_KEY = "include_regex"
1616
CONFIG_SAVE_ALL_KEY = "save_all"
1717
TENSORBOARD_CONFIG_FILE_PATH_ENV_STR = "TENSORBOARD_CONFIG_FILE_PATH"
18+
DEFAULT_SAGEMAKER_METRICS_PATH = "SAGEMAKER_METRICS_DIRECTORY"
1819
DEFAULT_SAGEMAKER_OUTDIR = "/opt/ml/output/tensors"
1920
DEFAULT_SAGEMAKER_TENSORBOARD_PATH = "/opt/ml/input/config/tensorboardoutputconfig.json"
2021
DEFAULT_COLLECTIONS_FILE_NAME = "worker_0_collections.json"

smdebug/core/hook.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232
from smdebug.core.utils import flatten, get_tb_worker, match_inc, size_and_shape
3333
from smdebug.core.writer import FileWriter
3434

35+
try:
36+
from smexperiments.metrics import SageMakerFileMetricsWriter
37+
except ImportError:
38+
from smdebug.core.metrics_file_writer import SageMakerFileMetricsWriter
39+
3540
logger = get_logger()
3641

3742

@@ -153,8 +158,15 @@ def __init__(
153158
self.mode = ModeKeys.GLOBAL
154159
self.mode_steps = {ModeKeys.GLOBAL: init_step}
155160
self.writer = None
161+
162+
self.metrics_writer = None
163+
156164
# Maps ModeKeys to FileWriter objects
157165
self.tb_writers = {}
166+
167+
# Cache scalars that are being saved through save_scalar() calls
168+
self.scalar_cache = []
169+
158170
self.logger.info("Saving to {}".format(self.out_dir))
159171
atexit.register(self._cleanup)
160172

@@ -309,6 +321,12 @@ def _close_writers(self) -> None:
309321
if self.dry_run:
310322
return
311323

324+
# flush out searchable scalars to metrics file
325+
if self.metrics_writer is not None:
326+
self._write_scalars()
327+
self.metrics_writer.close()
328+
self.metrics_writer = None
329+
312330
self._close_writer()
313331
to_delete_writers = []
314332

@@ -321,10 +339,11 @@ def _close_writers(self) -> None:
321339
for mode in to_delete_writers:
322340
del self.tb_writers[mode]
323341

324-
def _initialize_writer(self) -> None:
342+
def _initialize_writers(self) -> None:
    """Create the per-step tensor FileWriter and the SageMaker metrics writer.

    No-op in dry-run mode. Called lazily right before the first write of a
    step; `_close_writers` is the counterpart that flushes and tears both down.
    """
    if self.dry_run:
        return
    self.writer = FileWriter(trial_dir=self.out_dir, step=self.step, worker=self.worker)
    # Metrics writer targets the SageMaker metrics file (Minerva), separate
    # from the tensor output directory.
    self.metrics_writer = SageMakerFileMetricsWriter()
328347

329348
def get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
330349
"""
@@ -470,6 +489,16 @@ def _write_scalar_summary(self, tensor_name, tensor_value, save_colls):
470489
f"so scalar summary could not be created"
471490
)
472491
break
492+
for s_col in save_colls:
493+
if s_col.name in [
494+
CollectionKeys.LOSSES,
495+
CollectionKeys.SEARCHABLE_SCALARS,
496+
CollectionKeys.METRICS,
497+
]:
498+
np_val = self._make_numpy_array(tensor_value)
499+
# Always log loss to Minerva
500+
tensor_val = np.mean(np_val)
501+
self.scalar_cache.append((tensor_name, tensor_val, True))
473502

474503
def _write_histogram_summary(self, tensor_name, tensor_value, save_collections):
475504
""" Maybe write to TensorBoard. """
@@ -490,18 +519,45 @@ def _write_histogram_summary(self, tensor_name, tensor_value, save_collections):
490519
)
491520
break
492521

522+
def _write_scalars(self):
    """
    This function writes all the scalar values saved in the scalar_cache to file.
    If searchable is set to True for certain scalars, then that scalar is written to
    Minerva as well. By default, loss values are searchable.

    Drains and clears `self.scalar_cache`; entries are
    (name, value, searchable) tuples appended by save_scalar() and
    _write_scalar_summary().
    """
    # Lazily create writers; _initialize_writers also sets self.metrics_writer.
    # NOTE(review): if self.writer is set but self.metrics_writer is None,
    # the searchable branch below would fail — presumably both are always
    # created/torn down together; verify against _close_writers.
    if self.writer is None:
        self._initialize_writers()
    tb_writer = self._maybe_get_tb_writer()
    for scalar_name, scalar_val, searchable in self.scalar_cache:
        save_collections = self._get_collections_with_tensor(scalar_name)
        logger.debug(
            f"Saving scalar {scalar_name} {scalar_val} for step {self.step} {self.mode} "
            f"{self.mode_steps[self.mode]}"
        )
        if searchable:
            # Searchable scalars also go to the SageMaker metrics (Minerva) file,
            # indexed by the step number of the current mode.
            self.metrics_writer.log_metric(scalar_name, scalar_val, self.mode_steps[self.mode])
        # The raw-tensor write is gated on a TensorBoard writer being available.
        if tb_writer:
            self._write_raw_tensor(scalar_name, scalar_val, save_collections)
    # Cache is flushed once written; new save_scalar calls start fresh.
    self.scalar_cache = []
542+
493543
# Fix step number for saving scalar and tensor
494-
# def save_scalar(self, name, value):
495-
# get_collection(CollectionKeys.SCALARS).add_tensor_name(name)
496-
# if self.writer is None:
497-
# self._init_writer()
498-
# val = make_numpy_array(value)
499-
# if val.size != 1:
500-
# raise TypeError(
501-
# f'{name} has non scalar value of type: {type(value)}')
502-
# self._save_scalar_summary(name, val)
503-
# logger.debug(f'Saving scalar {name} {val} for step {self.step} {self.mode} {self.mode_steps[self.mode]}')
504-
# self._save_raw_tensor(name, val)
544+
def save_scalar(self, name, value, searchable=False):
    """
    Log a single scalar value (e.g. a metric) from anywhere in the training
    script.

    The value is cached and flushed to file by _write_scalars(); until the
    collections have been prepared, it stays in the cache.

    :param name: Name of the scalar. A prefix 'scalar/' will be added to it
    :param value: Scalar value
    :param searchable: True/False. If set to True, the scalar value will be written to
    SageMaker Minerva
    """
    name = CallbackHook.SCALAR_PREFIX + name
    # Normalize to a numpy array so we can validate it is truly a scalar.
    scalar_arr = self._make_numpy_array(value)
    if scalar_arr.size != 1:
        raise TypeError(f"{name} has non scalar value of type: {type(value)}")
    scalars_collection = self.collection_manager.get(CollectionKeys.SCALARS)
    scalars_collection.add_tensor_name(name)
    self.scalar_cache.append((name, scalar_arr, searchable))
    # Once collections are prepared, flush immediately instead of caching.
    if self.prepared_collections:
        self._write_scalars()
505561

506562
# def save_tensor(self, name, value):
507563
# # todo: support to add these tensors to any collection.
@@ -627,6 +683,7 @@ class CallbackHook(BaseHook):
627683
INPUT_TENSOR_SUFFIX = "_input_"
628684
OUTPUT_TENSOR_SUFFIX = "_output_"
629685
GRADIENT_PREFIX = "gradient/"
686+
SCALAR_PREFIX = "scalar/"
630687

631688
def __init__(
632689
self,

smdebug/core/metrics_file_writer.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Standard Library
2+
import json
3+
import os
4+
import time
5+
6+
# First Party
7+
from smdebug.core.config_constants import DEFAULT_SAGEMAKER_METRICS_PATH
8+
9+
METRICS_DIR = os.environ.get(DEFAULT_SAGEMAKER_METRICS_PATH, ".")
10+
11+
12+
class _RawMetricData(object):
    """Serializable record of one metric observation.

    Attribute names are CamelCase on purpose: the record is emitted via
    ``__dict__`` in the SageMaker metrics file format, whose JSON keys are
    MetricName / Value / Timestamp / IterationNumber.
    """

    def __init__(self, metric_name, value, iteration_number, timestamp):
        self.MetricName = metric_name
        self.Value = value
        self.Timestamp = timestamp
        self.IterationNumber = iteration_number


class SageMakerFileMetricsWriter(object):
    """Append metric records, one JSON object per line, to a metrics file.

    Fallback implementation used when the ``smexperiments`` package is not
    installed. The target file defaults to ``<METRICS_DIR>/<pid>.json``.
    """

    def __init__(self, filename=None):
        # Append mode so restarts/multiple writers don't clobber earlier metrics.
        self._file = open(filename or self._metrics_file_name(), "a")
        # Per-metric auto-incrementing iteration counters, used when the
        # caller does not pass an explicit iteration_number.
        self._indexes = {}
        self._closed = False

    def _metrics_file_name(self):
        # One file per process so concurrent workers don't interleave writes.
        return "{}/{}.json".format(METRICS_DIR, str(os.getpid()))

    def _write_metric_value(self, raw_metric_data):
        """Write one record as a JSON line, reopening the file if needed.

        Note: the unused ``file`` parameter of the original signature was
        dropped; this is a private method and all callers are in this class.
        """
        try:
            self._file.write(json.dumps(raw_metric_data.__dict__))
            self._file.write("\n")
        except AttributeError:
            # self._file is None: either the writer was closed (error),
            # or the handle was dropped — reopen and retry once.
            if self._closed:
                raise ValueError("log_metric called on a closed writer")
            elif not self._file:
                self._file = open(self._metrics_file_name(), "a")
                self._file.write(json.dumps(raw_metric_data.__dict__))
                self._file.write("\n")
            else:
                raise

    def log_metric(self, metric_name, value, iteration_number=None, timestamp=None):
        """Record one metric observation.

        :param metric_name: Name of the metric.
        :param value: Numeric value; coerced to float.
        :param iteration_number: Explicit step index. When None, a per-metric
            auto-incrementing counter is used instead.
        :param timestamp: Unix timestamp in seconds; defaults to now.
        :raises ValueError: if called after close().
        """
        timestamp = int(round(time.time())) if timestamp is None else int(timestamp)
        resolved_index = int(
            self._indexes.get(metric_name, 0) if iteration_number is None else iteration_number
        )

        value = float(value)
        assert isinstance(resolved_index, int)
        assert isinstance(timestamp, int)

        # Bug fix: write the resolved index (never None) instead of the raw
        # iteration_number argument, so auto-indexed records carry 0,1,2,...
        self._write_metric_value(
            _RawMetricData(metric_name, value, resolved_index, timestamp)
        )
        # Bug fix: advance the auto counter only when no explicit index was
        # given; the previous `if not iteration_number` also matched index 0.
        if iteration_number is None:
            self._indexes[metric_name] = resolved_index + 1

    def close(self):
        """Flush and close the underlying file; safe to call repeatedly."""
        if not self._closed and self._file:
            self._file.close()
            self._file = None
        self._closed = True

    def __enter__(self):
        """Return self"""
        return self

    def __exit__(self, type, value, traceback):
        """Execute self.close()"""
        self.close()

    def __del__(self):
        """Execute self.close()"""
        self.close()

smdebug/mxnet/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def _register_default_collections(self):
2525
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
2626
self.get(CollectionKeys.GRADIENTS).include("^gradient")
2727
self.get(CollectionKeys.LOSSES).include(".*loss")
28+
self.get(CollectionKeys.SCALARS).include("^scalar")
2829

2930
def create_collection(self, name):
3031
super().create_collection(name, cls=Collection)

smdebug/mxnet/hook.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
CollectionKeys.BIASES,
1818
CollectionKeys.GRADIENTS,
1919
CollectionKeys.LOSSES,
20+
CollectionKeys.SCALARS,
2021
]
2122

2223

@@ -140,7 +141,7 @@ def forward_pre_hook(self, block, inputs):
140141
self._increment_step()
141142

142143
if self._get_collections_to_save_for_step():
143-
self._initialize_writer()
144+
self._initialize_writers()
144145

145146
if self.exported_model is False:
146147
self._export_model()

smdebug/pytorch/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _register_default_collections(self):
4040
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
4141
self.get(CollectionKeys.GRADIENTS).include("^gradient")
4242
self.get(CollectionKeys.LOSSES).include("[Ll]oss")
43+
self.get(CollectionKeys.SCALARS).include("^scalar")
4344

4445
def create_collection(self, name):
4546
super().create_collection(name, cls=Collection)

smdebug/pytorch/hook.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from smdebug.pytorch.singleton_utils import set_hook
1313
from smdebug.pytorch.utils import get_reduction_of_data, make_numpy_array
1414

15-
DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]
15+
DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES, CollectionKeys.SCALARS]
1616

1717

1818
class Hook(CallbackHook):
@@ -43,6 +43,10 @@ def __init__(
4343
include_collections=include_collections,
4444
save_all=save_all,
4545
)
46+
# We would like to collect loss collection
47+
# even if user does not specify any collections
48+
if CollectionKeys.LOSSES not in self.include_collections:
49+
self.include_collections.append(CollectionKeys.LOSSES)
4650
# mapping of module objects to their names,
4751
# useful in forward hook for logging input/output of modules
4852
self.module_maps = dict()
@@ -143,7 +147,7 @@ def forward_pre_hook(self, module, inputs):
143147
self._increment_step()
144148

145149
if self._get_collections_to_save_for_step():
146-
self._initialize_writer()
150+
self._initialize_writers()
147151
self.log_params(module)
148152

149153
if self.last_saved_step is not None and not self.exported_collections:

smdebug/tensorflow/base_hook.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@
2727
is_parameter_server_strategy,
2828
)
2929

30+
try:
31+
from smexperiments.metrics import SageMakerFileMetricsWriter
32+
except ImportError:
33+
from smdebug.core.metrics_file_writer import SageMakerFileMetricsWriter
34+
35+
3036
DEFAULT_INCLUDE_COLLECTIONS = [
3137
CollectionKeys.METRICS,
3238
CollectionKeys.LOSSES,
3339
CollectionKeys.SCALARS,
40+
CollectionKeys.SEARCHABLE_SCALARS,
3441
]
3542

3643

@@ -183,7 +190,7 @@ def get_writers(self, tensor_name, tensor_ref) -> List[FileWriter]:
183190
else:
184191
return [self.writer]
185192

186-
def _initialize_writer(self, only_initialize_if_missing=False) -> None:
193+
def _initialize_writers(self, only_initialize_if_missing=False) -> None:
187194
# In keras, sometimes we are not sure if writer is initialized
188195
# (such as metrics at end of epoch), that's why it passes the flag only_init_if_missing
189196

@@ -203,11 +210,19 @@ def _initialize_writer(self, only_initialize_if_missing=False) -> None:
203210
else:
204211
if self.writer is None or only_initialize_if_missing is False:
205212
self.writer = FileWriter(trial_dir=self.out_dir, step=self.step, worker=self.worker)
213+
if self.metrics_writer is None or only_initialize_if_missing is False:
214+
self.metrics_writer = SageMakerFileMetricsWriter()
206215

207216
def _close_writer(self) -> None:
208217
if self.dry_run:
209218
return
210219

220+
# flush out searchable scalars to metrics file
221+
if self.metrics_writer is not None:
222+
self._write_scalars()
223+
self.metrics_writer.close()
224+
self.metrics_writer = None
225+
211226
if self.writer is not None:
212227
self.writer.flush()
213228
self.writer.close()

smdebug/tensorflow/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(self, collections=None, create_default=True):
138138
CollectionKeys.INPUTS,
139139
CollectionKeys.OUTPUTS,
140140
CollectionKeys.ALL,
141+
CollectionKeys.SEARCHABLE_SCALARS,
141142
]:
142143
self.create_collection(n)
143144
self.get(CollectionKeys.BIASES).include("bias")

smdebug/tensorflow/keras.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def _save_metrics(self, batch, logs, force_save=False):
316316
return
317317

318318
if force_save or self._is_collection_being_saved_for_step(CollectionKeys.METRICS):
319-
self._initialize_writer(only_initialize_if_missing=True)
319+
self._initialize_writers(only_initialize_if_missing=True)
320320
logs["batch"] = batch
321321
for key in logs:
322322
if key in ["loss", "val_loss", "outputs"]:
@@ -326,7 +326,7 @@ def _save_metrics(self, batch, logs, force_save=False):
326326
self._save_for_tensor(key, logs[key], check_before_write=False)
327327

328328
if force_save or self._is_collection_being_saved_for_step(CollectionKeys.LOSSES):
329-
self._initialize_writer(only_initialize_if_missing=True)
329+
self._initialize_writers(only_initialize_if_missing=True)
330330
for key in ["loss", "val_loss"]:
331331
if key in logs:
332332
self._add_metric(metric_name=key)
@@ -442,7 +442,7 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
442442

443443
if self.tensor_refs_to_save_this_step:
444444
# if saving metric, writer may not be initialized as a result
445-
self._initialize_writer()
445+
self._initialize_writers()
446446

447447
self._add_callbacks(mode)
448448

smdebug/tensorflow/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def _get_all_tensors_values(self, results):
306306

307307
def after_run(self, run_context, run_values):
308308
if self.tensors_to_save_this_step:
309-
self._initialize_writer()
309+
self._initialize_writers()
310310
for (tensor, value) in self._get_all_tensors_values(run_values.results):
311311
if tensor.dtype == tf.string:
312312
self._write_tf_summary(tensor, value)

0 commit comments

Comments
 (0)