Skip to content

Commit a125850

Browse files
authored
Fix create_trial breaking when many elements are read in the first attempt (aws#168)
1 parent cdfbdd1 commit a125850

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

smdebug/core/index_reader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class ReadIndexFilesCache:
5454
5555
Note: cache_limit is a soft limit.
5656
57-
In certain cases, the size of the cache can exceed this limit.
57+
The size of the cache can exceed this limit if eviction_point is 0.
58+
This can happen if start_after_key is None (unset) or is equal to the first element in the cache.
5859
5960
If start_after_key happens to be the first element in sorted(self.lookup_set), then we do not
6061
evict any element from the cache, but add more elements to cache.
@@ -76,6 +77,9 @@ def has_not_read(self, index_file: str) -> bool:
7677

7778
def _evict_cache(self, start_after_key: str) -> None:
7879
read_files = sorted(self.lookup_set)
80+
start_after_key = "" if start_after_key is None else start_after_key
81+
# eviction_point = 0 if start_after_key is None (unset).
82+
# This happens if more than self.cache_limit number of files are read on the first read attempt.
7983
eviction_point = bisect_left(read_files, start_after_key)
8084
for i in range(eviction_point):
8185
self.lookup_set.remove(read_files[i])

tests/core/test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ def test_index_files_cache():
9898
) # Elements in the cache will be file_4, file_5, file_6
9999

100100

101+
def test_index_files_cache_insert_many_elements_in_the_first_read():
102+
cache = ReadIndexFilesCache()
103+
cache.cache_limit = 5
104+
elements = ["a", "b", "c", "d", "e", "f", "g", "h"]
105+
for e in elements:
106+
cache.add(e, None)
107+
108+
# No files should be evicted because start_after_key has not been set
109+
assert len(cache.lookup_set) == len(elements)
110+
111+
101112
def test_get_prefix_from_index_file():
102113
local_index_filepath = "/opt/ml/testing/run_1/index/000000000/000000000000_worker_0.json"
103114
prefix = IndexFileLocationUtils.get_prefix_from_index_file(local_index_filepath)

0 commit comments

Comments
 (0)