Skip to content

Commit 524aba9

Browse files
authored
Fix out dir bug (aws#160)
* fix out dir bug * print mode.name instead of mode * print mode.name instead of mode * print mode.name instead of mode
1 parent 568733a commit 524aba9

File tree

6 files changed

+23
-11
lines changed

6 files changed

+23
-11
lines changed

tornasole/core/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def step(self, step_num, mode=ModeKeys.GLOBAL):
168168
avail_steps = self.trial.available_steps(mode=mode)
169169
if len(avail_steps) > 0:
170170
last_step = avail_steps[-1]
171-
raise NoMoreData("Looking for step:{} for mode {} and reached end of training. Max step available is {}".format(step_num, mode, last_step))
171+
raise NoMoreData(step_num, mode, last_step)
172172
raise StepNotYetAvailable(step_num, mode)
173173
assert False, 'Should not happen'
174174

tornasole/exceptions.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ def __init__(self, step, mode):
77
self.mode = mode
88

99
def __str__(self):
10-
return 'Step {} of mode {} not yet available'.format(self.step, self.mode)
10+
return 'Step {} of mode {} not yet available'.format(self.step,
11+
self.mode.name)
1112

1213

1314
class StepUnavailable(Exception):
@@ -17,7 +18,7 @@ def __init__(self, step, mode):
1718

1819
def __str__(self):
1920
return 'Step {} of mode {} is not available as it was not saved'\
20-
.format(self.step, self.mode)
21+
.format(self.step, self.mode.name)
2122

2223

2324
class TensorUnavailableForStep(Exception):
@@ -36,6 +37,7 @@ def __str__(self):
3637
'You might want to query for the reductions.'
3738
return msg
3839

40+
3941
class TensorUnavailable(Exception):
4042
def __init__(self, tname):
4143
self.tname = tname
@@ -46,7 +48,17 @@ def __str__(self):
4648

4749

4850
class NoMoreData(Exception):
49-
pass
51+
def __init__(self, step, mode, last_step):
52+
self.step = step
53+
self.mode = mode
54+
self.last_step = last_step
55+
56+
self.msg = "Looking for step {} of mode {} and reached " \
57+
"end of training. Max step available is {}"\
58+
.format(self.step, self.mode.name, self.last_step)
59+
60+
def __str__(self):
61+
return self.msg
5062

5163

5264
class RuleEvaluationConditionMet(Exception):

tornasole/mxnet/hook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def __init__(self,
3737
include_collections=DEFAULT_INCLUDE_COLLECTIONS,
3838
save_all=False):
3939
self.out_dir = verify_and_get_out_dir(out_dir)
40-
self.out_base_dir = os.path.dirname(out_dir)
41-
self.run_id = os.path.basename(out_dir)
40+
self.out_base_dir = os.path.dirname(self.out_dir)
41+
self.run_id = os.path.basename(self.out_dir)
4242
self.include_collections = include_collections
4343

4444
self.dry_run = dry_run

tornasole/pytorch/hook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def __init__(self,
3737
include_collections=DEFAULT_INCLUDE_COLLECTIONS,
3838
save_all=False):
3939
self.out_dir = verify_and_get_out_dir(out_dir)
40-
self.out_base_dir = os.path.dirname(out_dir)
41-
self.run_id = os.path.basename(out_dir)
40+
self.out_base_dir = os.path.dirname(self.out_dir)
41+
self.run_id = os.path.basename(self.out_dir)
4242
self.include_collections = include_collections
4343

4444
self.dry_run = dry_run

tornasole/tensorflow/hook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def __init__(self, out_dir=None,
7070
they are all saved in the collection `all`
7171
"""
7272
self.out_dir = verify_and_get_out_dir(out_dir)
73-
self.out_base_dir = os.path.dirname(out_dir)
74-
self.run_id = os.path.basename(out_dir)
73+
self.out_base_dir = os.path.dirname(self.out_dir)
74+
self.run_id = os.path.basename(self.out_dir)
7575

7676
self.dry_run = dry_run
7777
self.worker = worker if worker is not None else socket.gethostname()

tornasole/trials/trial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def wait_for_steps(self, required_steps, mode=ModeKeys.GLOBAL):
257257
avail_steps = self.available_steps(mode=mode)
258258
if len(avail_steps) > 0:
259259
last_step = avail_steps[-1]
260-
raise NoMoreData("Looking for step:{} for mode {} and reached end of training. Max step available for mode is {}".format(step, mode, last_step))
260+
raise NoMoreData(step, mode, last_step)
261261
time.sleep(5)
262262

263263
def has_passed_step(self, step, mode=ModeKeys.GLOBAL):

0 commit comments

Comments
 (0)