Skip to content

Commit 9eee151

Browse files
NihalHarishrahul003
authored andcommitted
mode writer support (aws#144)
1 parent fbbc239 commit 9eee151

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

tests/core/test_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def test_index():
3232
line_count = 0
3333
i = 0
3434
for row in csv_reader:
35-
count = int(row[2])
35+
count = int(row[-2])
3636
fo.seek(count, 0)
37-
end = int(row[3])
37+
end = int(row[-1])
3838
line = fo.read(end)
3939
zoo = open("test.txt", "wb")
4040
zoo.write(line)

tornasole/core/indexutils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,30 @@
11
class TensorLocation:
2-
def __init__(self, tname, event_file_name, start_idx, length):
2+
def __init__(self, tname, mode, mode_step, event_file_name, start_idx, length):
33
self.tensorname = tname
4+
self.mode = mode
5+
self.mode_step = mode_step
46
self.event_file_name = event_file_name
57
self.start_idx = start_idx
68
self.length = length
79

810
def serialize(self):
9-
return format(f'{self.tensorname}, {self.event_file_name}, {self.start_idx},{self.length}')
11+
return format(
12+
f'{self.tensorname},'
13+
f'{self.mode}, '
14+
f'{self.mode_step}, '
15+
f'{self.event_file_name}, '
16+
f'{self.start_idx}, '
17+
f'{self.length}')
1018

1119
@staticmethod
1220
def deserialize(manifest_line_str, manifest_key_name):
1321
arr = manifest_line_str.split(",")
14-
return TensorLocation(arr[0], arr[1], arr[2], arr[3])
22+
return TensorLocation(arr[0],
23+
arr[1],
24+
arr[2],
25+
arr[3],
26+
arr[4],
27+
arr[5])
1528

1629

1730
class IndexUtil:

tornasole/core/tfevent/event_file_writer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def write_tensor(self, tdata, tname, write_index,
226226
s = Summary(value=[Summary.Value(tag=tag, metadata=smd,
227227
tensor=tensor_proto)])
228228
if write_index:
229-
self.write_summary_with_index(s, self.step, tname)
229+
self.write_summary_with_index(s, self.step, tname, mode, mode_step)
230230
else:
231231
self.write_summary(s, self.step)
232232

@@ -236,11 +236,11 @@ def write_summary(self, summary, step):
236236
event.step = step
237237
self.write_event(event)
238238

239-
def write_summary_with_index(self, summary, step, tname):
239+
def write_summary_with_index(self, summary, step, tname, mode, mode_step):
240240
event = Event(summary=summary)
241241
event.wall_time = time.time()
242242
event.step = step
243-
return self.write_event(IndexArgs(event, tname))
243+
return self.write_event(IndexArgs(event, tname, mode, mode_step))
244244

245245
def write_event(self, event):
246246
"""Adds an event to the event file."""
@@ -302,10 +302,12 @@ def run(self):
302302
if isinstance(event_in_queue, IndexArgs):
303303
tname = event_in_queue.tensorname
304304
eventfile = self._ev_writer.name()
305+
mode = event_in_queue.get_mode()
306+
mode_step = event_in_queue.get_mode_step()
305307
s3, _, _ = is_s3(eventfile)
306308
if not s3:
307309
eventfile = os.path.abspath(self._ev_writer.name())
308-
tensorlocation = TensorLocation(tname, eventfile, positions[0], positions[1])
310+
tensorlocation = TensorLocation(tname, mode, mode_step, eventfile, positions[0], positions[1])
309311
self._ev_writer.indexwriter.add_index(tensorlocation)
310312
# Flush the event writer every so often.
311313
now = time.time()

tornasole/core/tfevent/index_file_writer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from tornasole.core.access_layer.s3 import TSAccessS3
33
from tornasole.core.utils import is_s3
44

5+
56
class IndexWriter(object):
67
def __init__(self, file_path):
78
self.file_path = file_path
@@ -38,13 +39,22 @@ def close(self):
3839
self.writer.close()
3940
self.writer = None
4041

42+
4143
class IndexArgs(object):
42-
def __init__(self, event, tensorname):
44+
def __init__(self, event, tensorname, mode, mode_step):
4345
self.event = event
4446
self.tensorname = tensorname
47+
self.mode = mode
48+
self.mode_step = mode_step
4549

4650
def get_event(self):
4751
return self.event
4852

4953
def get_tensorname(self):
5054
return self.tensorname
55+
56+
def get_mode(self):
57+
return str(self.mode).split('.')[-1]
58+
59+
def get_mode_step(self):
60+
return self.mode_step

0 commit comments

Comments
 (0)