Skip to content

Commit 2ada1d4

Browse files
NihalHarishVikas-kum
authored andcommitted
Fix has_passed_steps and tests for has-passed-steps with mode steps (aws#147)
* tests for has-passed-step-tests
1 parent bdf55e4 commit 2ada1d4

File tree

2 files changed

+148
-2
lines changed

2 files changed

+148
-2
lines changed

smdebug/trials/trial.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,11 +517,13 @@ def has_passed_step(self, step, mode=ModeKeys.GLOBAL) -> StepState:
517517
"""
518518
all_steps = self.steps(mode=mode, show_incomplete_steps=True)
519519
bisect_idx = bisect_left(all_steps, step)
520-
g_step = self._global_step_currently(mode, step)
521520

522521
if bisect_idx < len(all_steps):
522+
# This returns either the global step corresponding to the mode-step
523+
# or the closest global step that is greater than the step passed as a parameter
524+
g_step = self._global_step_currently(mode, all_steps[bisect_idx])
523525
if all_steps[bisect_idx] > step:
524-
if self.last_complete_step > g_step:
526+
if self.last_complete_step >= g_step:
525527
return StepState.UNAVAILABLE
526528
return StepState.NOT_YET_AVAILABLE
527529
elif all_steps[bisect_idx] == step:

tests/analysis/trials/test_has_passed_step_scenarios.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,53 @@
11
# Standard Library
2+
import json
3+
import os
4+
import shutil
5+
import uuid
6+
from pathlib import Path
27

38
# Third Party
49
import pytest
510

611
# First Party
12+
from smdebug.core.collection_manager import CollectionManager
713
from smdebug.core.config_constants import INCOMPLETE_STEP_WAIT_WINDOW_KEY
14+
from smdebug.core.locations import IndexFileLocationUtils
15+
from smdebug.core.modes import ModeKeys
816
from smdebug.core.tensor import StepState
917
from smdebug.exceptions import NoMoreData, StepUnavailable
1018
from smdebug.trials import create_trial
1119

1220

21+
def dummy_trial_creator(trial_dir, num_workers, job_ended):
22+
Path(trial_dir).mkdir(parents=True, exist_ok=True)
23+
cm = CollectionManager()
24+
for i in range(num_workers):
25+
collection_file_name = f"worker_{i}_collections.json"
26+
cm.export(trial_dir, collection_file_name)
27+
if job_ended:
28+
Path(os.path.join(trial_dir, "training_job_end.ts")).touch()
29+
30+
31+
def dummy_step_creator(trial_dir, global_step, mode, mode_step, worker_name):
32+
static_step_data = (
33+
'{"meta": {"mode": "TRAIN", "mode_step": 0, "event_file_name": ""}, '
34+
'"tensor_payload": ['
35+
'{"tensorname": "gradients/dummy:0", "start_idx": 0, "length": 1}'
36+
"]}"
37+
)
38+
39+
step = json.loads(static_step_data)
40+
step["meta"]["mode"] = mode
41+
step["meta"]["mode_step"] = mode_step
42+
43+
index_file_location = IndexFileLocationUtils.get_index_key_for_step(
44+
trial_dir, global_step, worker_name
45+
)
46+
Path(os.path.dirname(index_file_location)).mkdir(parents=True, exist_ok=True)
47+
with open(index_file_location, "w") as f:
48+
json.dump(step, f)
49+
50+
1351
@pytest.mark.slow
1452
def test_single_writer_all_steps_written_complete_job():
1553
"""Test Scenario Description"
@@ -38,6 +76,112 @@ def test_single_writer_all_steps_written_complete_job():
3876
assert trial.last_complete_step == 6
3977

4078

79+
@pytest.mark.slow
80+
def test_single_writer_all_steps_written_complete_job_two_modes():
81+
"""Test Scenario Description"
82+
workers : [a]
83+
modes: TRAIN, EVAL
84+
steps :{
85+
0: [worker:a, mode: TRAIN, mode_step: 0],
86+
10: [worker:a, mode: TRAIN, mode_step: 10],
87+
20: [worker:a, mode: TRAIN, mode_step: 20],
88+
30: [worker:a, mode: TRAIN, mode_step: 30],
89+
40: [worker:a, mode: EVAL, mode_step: 0],
90+
50: [worker:a, mode: EVAL, mode_step: 10],
91+
60: [worker:a, mode: EVAL, mode_step: 20],
92+
70: [worker:a, mode: EVAL, mode_step: 30]
93+
}
94+
END_OF_JOB.ts --> Present
95+
"""
96+
97+
path = os.path.join("ts_output/train/", str(uuid.uuid4()))
98+
dummy_trial_creator(trial_dir=path, num_workers=1, job_ended=True)
99+
for i in range(0, 31, 10):
100+
dummy_step_creator(
101+
trial_dir=path, global_step=i, mode="TRAIN", mode_step=i, worker_name="worker_0"
102+
)
103+
104+
for i in range(0, 31, 10):
105+
dummy_step_creator(
106+
trial_dir=path, global_step=i + 40, mode="EVAL", mode_step=i, worker_name="worker_0"
107+
)
108+
109+
trial = create_trial(path)
110+
num_workers = len(trial.workers())
111+
assert num_workers == 1
112+
assert trial.loaded_all_steps is True
113+
all_steps = trial.steps(show_incomplete_steps=True)
114+
completed_steps = trial.steps()
115+
assert all_steps == [0, 10, 20, 30, 40, 50, 60, 70]
116+
assert completed_steps == all_steps
117+
assert trial.has_passed_step(30) == StepState.AVAILABLE
118+
assert trial.has_passed_step(23, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
119+
assert trial.has_passed_step(40, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
120+
assert trial.has_passed_step(30, mode=ModeKeys.EVAL) == StepState.AVAILABLE
121+
assert trial.has_passed_step(23, mode=ModeKeys.EVAL) == StepState.UNAVAILABLE
122+
assert trial.has_passed_step(80) == StepState.UNAVAILABLE
123+
assert trial.has_passed_step(80, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
124+
assert trial.has_passed_step(80, mode=ModeKeys.EVAL) == StepState.UNAVAILABLE
125+
assert trial.last_index_token == os.path.join(
126+
path, "index/000000000/000000000070_worker_0.json"
127+
)
128+
assert trial.last_complete_step == 70
129+
shutil.rmtree(path, ignore_errors=True)
130+
131+
132+
@pytest.mark.slow
133+
def test_single_writer_all_steps_written_incomplete_job_two_modes():
134+
"""Test Scenario Description"
135+
workers : [a]
136+
modes: TRAIN, EVAL
137+
steps :{
138+
0: [worker:a, mode: TRAIN, mode_step: 0],
139+
10: [worker:a, mode: TRAIN, mode_step: 10],
140+
20: [worker:a, mode: TRAIN, mode_step: 20],
141+
30: [worker:a, mode: TRAIN, mode_step: 30],
142+
40: [worker:a, mode: EVAL, mode_step: 0],
143+
50: [worker:a, mode: EVAL, mode_step: 10],
144+
60: [worker:a, mode: EVAL, mode_step: 20],
145+
70: [worker:a, mode: EVAL, mode_step: 30]
146+
}
147+
END_OF_JOB.ts --> Absent
148+
"""
149+
150+
path = os.path.join("ts_output/train/", str(uuid.uuid4()))
151+
dummy_trial_creator(trial_dir=path, num_workers=1, job_ended=False)
152+
for i in range(0, 31, 10):
153+
dummy_step_creator(
154+
trial_dir=path, global_step=i, mode="TRAIN", mode_step=i, worker_name="worker_0"
155+
)
156+
157+
for i in range(0, 31, 10):
158+
dummy_step_creator(
159+
trial_dir=path, global_step=i + 40, mode="EVAL", mode_step=i, worker_name="worker_0"
160+
)
161+
162+
trial = create_trial(path)
163+
num_workers = len(trial.workers())
164+
assert num_workers == 1
165+
assert trial.loaded_all_steps is False
166+
all_steps = trial.steps(show_incomplete_steps=True)
167+
completed_steps = trial.steps()
168+
assert all_steps == [0, 10, 20, 30, 40, 50, 60, 70]
169+
assert completed_steps == all_steps
170+
assert trial.has_passed_step(30) == StepState.AVAILABLE
171+
assert trial.has_passed_step(23, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
172+
assert trial.has_passed_step(40, mode=ModeKeys.TRAIN) == StepState.NOT_YET_AVAILABLE
173+
assert trial.has_passed_step(30, mode=ModeKeys.EVAL) == StepState.AVAILABLE
174+
assert trial.has_passed_step(23, mode=ModeKeys.EVAL) == StepState.UNAVAILABLE
175+
assert trial.has_passed_step(80) == StepState.NOT_YET_AVAILABLE
176+
assert trial.has_passed_step(80, mode=ModeKeys.TRAIN) == StepState.NOT_YET_AVAILABLE
177+
assert trial.has_passed_step(80, mode=ModeKeys.EVAL) == StepState.NOT_YET_AVAILABLE
178+
assert trial.last_index_token == os.path.join(
179+
path, "index/000000000/000000000070_worker_0.json"
180+
)
181+
assert trial.last_complete_step == 70
182+
shutil.rmtree(path, ignore_errors=True)
183+
184+
41185
@pytest.mark.slow
42186
def test_single_writer_all_steps_written_incomplete_job():
43187
"""Test Scenario Description"

0 commit comments

Comments
 (0)