Skip to content

Commit a99e163

Browse files
NihalHarishrahul003
authored and committed
Stop Waiting For Collection Files If Training Has Ended (aws#51)
* stop waiting if training has ended * fix incorrect merge * Fail if collection files missing
1 parent 57bb732 commit a99e163

File tree

3 files changed

+89
-12
lines changed

3 files changed

+89
-12
lines changed

smdebug/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@ def __str__(self):
1111
return "Step {} of mode {} not yet available".format(self.step, self.mode.name)
1212

1313

14+
class MissingCollectionFiles(Exception):
    """Signals that the training job ended before all collection files could be loaded.

    The exception takes no constructor arguments; its string form is a fixed message.
    """

    # Fixed, user-facing text returned by ``__str__``.
    _MESSAGE = "Training job has ended. All the collection files could not be loaded"

    def __init__(self):
        # No payload is carried; ``args`` stays empty.
        super().__init__()

    def __str__(self):
        return self._MESSAGE
20+
21+
1422
class IndexReaderException(Exception):
1523
def __init__(self, message):
1624
self.message = message

smdebug/trials/trial.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
match_inc,
2626
serialize_tf_device,
2727
)
28-
from smdebug.exceptions import NoMoreData, StepUnavailable, TensorUnavailable
28+
from smdebug.exceptions import (
29+
MissingCollectionFiles,
30+
NoMoreData,
31+
StepUnavailable,
32+
TensorUnavailable,
33+
)
2934

3035

3136
class Trial(ABC):
@@ -149,22 +154,21 @@ def _fetch():
149154
"Waiting to read collections files generated by the training job."
150155
)
151156

152-
def _wait_for_first_collection_file():
153-
while len(collection_files) == 0:
154-
time.sleep(2)
155-
_fetch()
156-
157-
def _wait_for_all_collection_files():
158-
while len(collection_files) < self.num_workers:
157+
def _wait_for_collection_files(number_of_collection_file_to_wait_for):
    """Poll until at least the requested number of collection files is listed.

    Sleeps 2 seconds between listings. The loop condition re-reads
    ``collection_files`` from the enclosing scope after each ``_fetch()``
    call, so it relies on ``_fetch()`` refreshing that list.

    Args:
        number_of_collection_file_to_wait_for: minimum number of entries
            ``collection_files`` must contain before this function returns.

    Raises:
        MissingCollectionFiles: if the training job has ended and, even after
            one more listing, fewer files than requested are present.
    """
    while len(collection_files) < number_of_collection_file_to_wait_for:
        time.sleep(2)
        _fetch()
        if len(
            collection_files
        ) < number_of_collection_file_to_wait_for and has_training_ended(self.path):
            # The listing above ran *before* the end-of-training check, so files
            # written between the two calls would be missed. List once more now
            # that the job is known to be finished; after this, any shortfall
            # is final.
            _fetch()
            if len(collection_files) < number_of_collection_file_to_wait_for:
                raise MissingCollectionFiles
163165

164166
_fetch()
165-
_wait_for_first_collection_file()
167+
_wait_for_collection_files(1) # wait for the first collection file
166168
self._read_collections(collection_files)
167-
_wait_for_all_collection_files()
169+
_wait_for_collection_files(self.num_workers) # wait for all the collection files
170+
for collection_file in collection_files:
171+
self.worker_set.add(get_worker_name_from_collection_file(collection_file))
168172

169173
@abstractmethod
170174
def _load_tensors_from_index_tensors(self, index_tensors_dict):
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Standard Library
2+
3+
# Third Party
4+
import pytest
5+
6+
# First Party
7+
from smdebug.exceptions import MissingCollectionFiles
8+
from smdebug.trials import create_trial
9+
10+
11+
@pytest.mark.slow
def test_load_collection_files_from_completed_job():
    """
    Number of collection files : 2001
    Training_has_ended.ts : Present

    All the collection files have been written in the test dataset
    and the training_has_ended file is present, so creating the trial
    must succeed and report every worker.
    """
    path = "s3://tornasole-testing/collection-tests/all-collection-files-present/"
    try:
        trial = create_trial(path)
    except MissingCollectionFiles:
        # pytest.fail is not stripped under `python -O` (unlike `assert False`)
        # and reports an explicit reason.
        pytest.fail("MissingCollectionFiles raised although all collection files are present")
    assert len(trial.workers()) == 2001
27+
28+
29+
@pytest.mark.slow
def test_load_collection_files_from_completed_job_with_missing_files():
    """
    Number of collection files : 1446
    Training_has_ended.ts : Present

    Some of the collection files have been removed in the test dataset.
    The number of expected collection files is supposed to be 2001,
    but the training_has_ended file is present, so we stop waiting
    and expect MissingCollectionFiles to be raised.
    """
    path = "s3://tornasole-testing/collection-tests/collection-files-missing/"
    # pytest.raises fails the test if the exception is NOT raised, replacing the
    # manual try / assert False / except / assert True pattern.
    with pytest.raises(MissingCollectionFiles):
        create_trial(path)
46+
47+
48+
@pytest.mark.slow
def test_load_collection_files_from_incomplete_job():
    """
    Number of collection files : 2001
    Training_has_ended.ts : Absent

    All the collection files have been written in the test dataset
    and the training_has_ended file is absent, so creating the trial
    must succeed and report every worker.
    """
    path = "s3://tornasole-testing/collection-tests/all-collection-files-present-job-incomplete/"
    try:
        trial = create_trial(path)
    except MissingCollectionFiles:
        # pytest.fail is not stripped under `python -O` (unlike `assert False`)
        # and reports an explicit reason.
        pytest.fail("MissingCollectionFiles raised although all collection files are present")
    assert len(trial.workers()) == 2001

0 commit comments

Comments
 (0)