Skip to content

Commit 9319924

Browse files
authored
Add support to save mode and mode step in the summary (#38)
* core mode changes * update with changes to writer and reader * update tensor reader * simplified remove tensor as tf is using its own method now * implemented except when_nan * tensor can match multiple collections * save manager working, and tensor reduction change * delete save manager, move to new PR * Revert "delete save manager, move to new PR" This reverts commit 475122de1c5613c800fdbfc9abdb1c2d24a64add. * update to support mxnet as well * Changes for mxnet * remove default * reorg and fix tests * Added string constant * add comment and an allowed_modes list
1 parent 4ce12e5 commit 9319924

19 files changed

+315
-34
lines changed

tests/test_collections.py renamed to tests/core/test_collections.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def test_load_empty():
1515

1616
def test_manager_export_load():
1717
cm = CollectionManager()
18+
cm.create_collection('default')
1819
cm.get('default').include('loss')
1920
cm.add(Collection('trial1'))
2021
cm.add('trial2')
@@ -25,6 +26,7 @@ def test_manager_export_load():
2526

2627
def test_manager():
2728
cm = CollectionManager()
29+
cm.create_collection('default')
2830
cm.get('default').include('loss')
2931
cm.add(Collection('trial1'))
3032
cm.add('trial2')
File renamed without changes.
File renamed without changes.

tests/core/test_modes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from tornasole_core.writer import FileWriter
2+
from tornasole_core.reader import FileReader
3+
import numpy as np
4+
from tornasole_core.modes import ModeKeys
5+
from datetime import datetime
6+
import glob
7+
import shutil
8+
9+
def test_mode_writing():
10+
run_id = 'trial_' + datetime.now().strftime('%Y%m%d-%H%M%S%f')
11+
for s in range(0, 10):
12+
13+
fw = FileWriter(logdir='ts_outputs', trial=run_id, step=s)
14+
if s % 2 == 0:
15+
fw.write_tensor(tdata=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
16+
tname='arr', mode=ModeKeys.TRAIN, mode_step=s//2)
17+
else:
18+
fw.write_tensor(tdata=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
19+
tname='arr', mode=ModeKeys.EVAL, mode_step=s // 2)
20+
fw.close()
21+
files = glob.glob('ts_outputs/' + run_id + '/**/*.tfevents',
22+
recursive=True)
23+
for f in files:
24+
fr = FileReader(fname=f)
25+
for tu in fr.read_tensors():
26+
tensor_name, step, tensor_data, mode, mode_step = tu
27+
if step % 2 == 0:
28+
assert mode == ModeKeys.TRAIN
29+
else:
30+
assert mode == ModeKeys.EVAL
31+
assert mode_step == step // 2
32+
shutil.rmtree('ts_outputs/' + run_id)

tests/test_numpy.py renamed to tests/core/test_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def test_s3():
3434
my_session = boto3.session.Session()
3535
my_region = my_session.region_name
3636
my_account = boto3.client('sts').get_caller_identity().get('Account')
37-
bucket_name = 'sagemaker-{}-{}'.format(my_region,my_account)
38-
key_name = 'tornasole/{}'.format(str(uuid.uuid4()))
37+
bucket_name = 'tornasole-testing'
38+
key_name = 'core-tests/tornasole/{}'.format(str(uuid.uuid4()))
3939
#sagemaker-us-east-1-722321484884
4040
location = 's3://{}/{}'.format(bucket_name,key_name)
4141
print("Saving to Location")
File renamed without changes.
File renamed without changes.
File renamed without changes.

tornasole_core/collection.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,10 @@ def add_tensor(self, t):
8787
self.tensors.append(t)
8888

8989
def remove_tensor(self, t):
90-
# have to compare names because tensors can have variables, \
91-
# we don't want to end up comparing tensors and variables
9290
if t.name in self.tensor_names:
93-
found_index = None
94-
for i, lt in enumerate(self.tensors):
95-
if lt.name == t.name:
96-
found_index = i
97-
assert found_index is not None
98-
self.tensors.pop(found_index)
9991
self.tensor_names.remove(t.name)
92+
if t in self.tensors:
93+
self.tensors.remove(t)
10094

10195
def add_reduction_tensor(self, s):
10296
self.reduction_tensor_names.append(s.name)

tornasole_core/collection_manager.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,19 @@ class CollectionManager:
88
It contains a default collection into which tensors are inserted
99
without specifying collection name
1010
"""
11-
def __init__(self, create_default=True):
11+
def __init__(self):
1212
self.collections = {}
13-
if create_default:
14-
self.collections['default'] = self.get_new_collection('default')
1513

16-
def get_new_collection(self, name):
17-
return Collection(name)
14+
def create_collection(self, name):
15+
self.collections[name] = Collection(name)
1816

1917
def get_collections(self):
2018
return self.collections
2119

2220
def add(self, arg):
2321
if isinstance(arg, str):
2422
if arg not in self.collections:
25-
self.collections[arg] = self.get_new_collection(arg)
23+
self.create_collection(arg)
2624
elif isinstance(arg, Collection):
2725
if arg.name not in self.collections:
2826
self.collections[arg.name] = arg
@@ -43,7 +41,7 @@ def export(self, filename):
4341

4442
@staticmethod
4543
def load(filename):
46-
cm = CollectionManager(create_default=False)
44+
cm = CollectionManager()
4745
with open(filename, 'r') as f:
4846
line = f.readline()
4947
while line:
@@ -54,7 +52,7 @@ def load(filename):
5452

5553
@staticmethod
5654
def load_from_string(s):
57-
cm = CollectionManager(create_default=False)
55+
cm = CollectionManager()
5856
lines = s.split('\n')
5957
for line in lines:
6058
c = Collection.load(line.rstrip())

tornasole_core/modes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from enum import Enum
2+
3+
# Note that Keras has similar concept of ModeKeys
4+
class ModeKeys(Enum):
5+
TRAIN = 1 #training/fitting mode
6+
EVAL = 2 # testing/evaluation mode
7+
PREDICT = 3 # prediction/inference mode
8+
GLOBAL = 4
9+
10+
ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]
11+
MODE_STEP_PLUGIN_NAME = "mode_step"
12+
MODE_PLUGIN_NAME = "mode"

tornasole_core/save_config.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,36 @@
11
SAVE_CONFIG_VERSION_NUM = 'v0'
2+
from .modes import ModeKeys as modes
3+
4+
5+
class SaveConfigModes:
6+
def __init__(self, mode_save_configs=None):
7+
if mode_save_configs is None:
8+
mode_save_configs = {}
9+
self.mode_save_configs = mode_save_configs
10+
11+
def add_for_all_modes(self, save_config):
12+
for mode in modes:
13+
self.mode_save_configs[mode] = save_config
14+
15+
def add(self, mode, save_config):
16+
self.mode_save_configs[mode] = save_config
17+
18+
def should_save_step(self, mode, step_num):
19+
return self.mode_save_configs[mode].should_save_step(step_num)
20+
21+
def add_when_nan_tensor(self, tensor):
22+
for mode in modes:
23+
self.mode_save_configs[mode].when_nan_tensors.append(tensor)
24+
25+
def get_save_config(self, mode):
26+
return self.mode_save_configs[mode]
27+
28+
@staticmethod
29+
def create_simple_save_mode(save_config):
30+
sm = SaveConfigModes()
31+
sm.add_for_all_modes(save_config)
32+
return sm
33+
234

335
class SaveConfig:
436
"""
@@ -31,6 +63,9 @@ def __init__(self, save_interval=100, skip_num_steps=0, save_steps=None, when_na
3163
self.skip_num_steps = skip_num_steps
3264
self.when_nan = when_nan if when_nan is not None else []
3365

66+
# will be populated by hook
67+
self.when_nan_tensors = []
68+
3469
def export(self):
3570
separator = '%'
3671
list_separator = ','

tornasole_core/save_manager.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from .save_config import SaveConfig, SaveConfigModes
2+
from .utils import match_inc
3+
4+
class SaveManager:
5+
def __init__(self, collection_manager, include_collections_names,
6+
default_reduction_config,
7+
default_save_config):
8+
self.configs_for_collections = {}
9+
if isinstance(default_save_config, SaveConfig):
10+
sm = SaveConfigModes.create_simple_save_mode(default_save_config)
11+
self.default_save_modes = sm
12+
elif isinstance(default_save_config, SaveConfigModes):
13+
self.default_save_modes = default_save_config
14+
elif isinstance(default_save_config, dict):
15+
self.default_save_modes = SaveConfigModes(default_save_config)
16+
else:
17+
raise TypeError('save_config can only be a SaveConfig instance, or '
18+
'a dictionary mapping from mode '
19+
'to SaveConfig instance.')
20+
self.default_reduction_config = default_reduction_config
21+
self.collection_manager = collection_manager
22+
self.include_collections_names = include_collections_names
23+
self.save_collections = []
24+
self.save_states_cache = {}
25+
# todo clear cache for old steps
26+
self.tensor_to_collection = {}
27+
self.when_nan_tensors = {}
28+
29+
def prepare(self):
30+
# below is to control the order
31+
# in which these collections appear in save_collections
32+
for cname in ['weights', 'gradients', 'bias', 'optimizer_variables']:
33+
if self._should_collection_be_saved(cname):
34+
self.save_collections.append(self.collection_manager.get(cname))
35+
36+
# adding other collections to save_collections
37+
for c_name, c in self.collection_manager.get_collections().items():
38+
if self._should_collection_be_saved(c_name) \
39+
and c not in self.save_collections:
40+
self.save_collections.append(c)
41+
42+
for c_name, c in self.collection_manager.get_collections().items():
43+
if c.save_config is not None:
44+
if isinstance(c.save_config, dict):
45+
self.configs_for_collections[c_name] = c.save_config
46+
elif isinstance(c.save_config, SaveConfig):
47+
sm = SaveConfigModes.create_simple_save_mode(c.save_config)
48+
self.configs_for_collections[c_name] = sm
49+
else:
50+
raise TypeError('collection {} has save config of wrong type {}'
51+
.format(c_name, type(c.save_config)))
52+
else:
53+
self.configs_for_collections[c_name] = self.default_save_modes
54+
55+
if c.reduction_config is None and self.default_reduction_config is not None:
56+
c.reduction_config = self.default_reduction_config
57+
58+
def _should_collection_be_saved(self, coll_name):
59+
return coll_name in self.include_collections_names
60+
61+
def get_all_collections_to_save(self):
62+
return self.save_collections
63+
64+
def collections_to_save(self, mode, step):
65+
if (mode, step) not in self.save_states_cache:
66+
collection_save_state = {}
67+
for coll in self.save_collections:
68+
sm = self.configs_for_collections[coll.name]
69+
rv = sm.should_save_step(mode, step)
70+
if any(rv.values()):
71+
collection_save_state[coll.name] = rv
72+
self.save_states_cache[(mode, step)] = collection_save_state
73+
return self.save_states_cache[(mode, step)]
74+
75+
def get_save_config(self, collection, mode):
76+
return self.configs_for_collections[collection.name].get_save_config(mode)
77+
78+
def from_collections(self, tensor_name):
79+
# for tf this will be prepopulated because of prepare_tensors
80+
if not tensor_name in self.tensor_to_collection:
81+
# for mxnet it is computed and then cached
82+
matched_colls = []
83+
for coll in self.save_collections:
84+
if tensor_name in coll.tensor_names:
85+
matched_colls.append(coll)
86+
elif match_inc(tensor_name, coll.get_include_regex()):
87+
matched_colls.append(coll)
88+
self.tensor_to_collection[tensor_name] = matched_colls
89+
return self.tensor_to_collection[tensor_name]
90+
91+
def should_save_tensor(self, tensorname, mode, step):
92+
# returns dictionary with two keys:
93+
# if value for step is true in the dict, then we are saving this tensor
94+
# because we have hit the step to save this
95+
# if value for when_nan is true, we are considering saving this tensor
96+
# because this tensor might be saved if some other tensor is nan
97+
colls = self.from_collections(tensorname)
98+
final_ss = {'step': False, 'when_nan': False}
99+
ss_colls = self.collections_to_save(mode, step)
100+
for c in colls:
101+
if c.name in ss_colls:
102+
ss = ss_colls[c.name]
103+
final_ss['step'] = final_ss['step'] or ss['step']
104+
final_ss['when_nan'] = final_ss['when_nan'] or ss['when_nan']
105+
return final_ss
106+
107+
# below are used only by TF
108+
def prepare_tensors(self):
109+
for c_name, c in self.collection_manager.get_collections().items():
110+
if c_name == 'when_nan':
111+
continue
112+
if c not in self.save_collections:
113+
continue
114+
for t in c.tensors + c.reduction_tensors:
115+
self._add_tensor_to_collection(t, c)
116+
117+
def _add_tensor_to_collection(self, t, c):
118+
if t.name not in self.tensor_to_collection:
119+
self.tensor_to_collection[t.name] = [c]
120+
else:
121+
self.tensor_to_collection[t.name].append(c)
122+
123+
def add_when_nan_tensor(self, collection, tensor):
124+
self.configs_for_collections[collection.name].add_when_nan_tensor(tensor)
125+
if tensor.name not in self.when_nan_tensors:
126+
self.when_nan_tensors[tensor.name] = []
127+
self.when_nan_tensors[tensor.name].append(collection)
128+
self._add_tensor_to_collection(tensor, collection)
129+
130+
if 'when_nan' not in self.collection_manager.collections:
131+
self.collection_manager.create_collection('when_nan')
132+
self.collection_manager.get('when_nan').add_tensor(tensor)
133+
134+
def is_when_nan_tensor(self, tensor_name):
135+
return tensor_name in self.when_nan_tensors
136+
137+
def when_nan_collections(self, tensor_name):
138+
return self.when_nan_tensors[tensor_name]

tornasole_core/tfevent/event_file_reader.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from .summary_pb2 import Summary, SummaryMetadata
2929

3030
from tornasole_core.tfrecord.record_reader import RecordReader
31+
from tornasole_core.modes import ModeKeys, MODE_STEP_PLUGIN_NAME, MODE_PLUGIN_NAME
32+
3133

3234
#todo: remove this logger perhaps
3335
logging.basicConfig()
@@ -133,7 +135,7 @@ def __exit__(self,exc_type, exc_value, traceback):
133135
self._ev_reader.__exit__(exc_type, exc_value, traceback)
134136

135137
def read_tensors(self, read_data=False, check=False):
136-
for (step,summ) in self.read_summaries(check=check):
138+
for step, summ in self.read_summaries(check=check):
137139
for v in summ.value:
138140
assert v.WhichOneof('value') == 'tensor'
139141
tensor_name = v.tag
@@ -142,7 +144,17 @@ def read_tensors(self, read_data=False, check=False):
142144
tensor_data = get_tensor_data(v.tensor)
143145
else:
144146
tensor_data = None
145-
yield (tensor_name, step, tensor_data)
147+
148+
# default values
149+
# todo: validate the logic extensively
150+
mode_step = step
151+
mode = ModeKeys.GLOBAL
152+
for metadata in v.metadata.plugin_data:
153+
if metadata.plugin_name == MODE_STEP_PLUGIN_NAME:
154+
mode_step = int(metadata.content)
155+
if metadata.plugin_name == MODE_PLUGIN_NAME:
156+
mode = ModeKeys(int(metadata.content))
157+
yield (tensor_name, step, tensor_data, mode, mode_step)
146158

147159
def read_summaries(self, check=True):
148160
for ev in self.read_events(check=check):

0 commit comments

Comments
 (0)