This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 2001dcc (1 parent: 38659d1)

Make sure that default indices increase continuously when reading in chunks

This fixes #24

3 files changed (+39, −11): README.md, root_pandas/readwrite.py, tests/test.py

README.md (3 additions, 0 deletions)

````diff
@@ -76,6 +76,9 @@ If the `chunksize` parameter is specified, `read_root` returns an iterator that
 for df in read_root('bigfile.root', chunksize=100000):
     # process df here
 ```
+If `bigfile.root` doesn't contain an index, the default indices of the
+individual `DataFrame` chunks will still increase continuously, as if they were
+parts of a single large `DataFrame`.
 
 You can also combine any of the above options at the same time.
 
````
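A short usage sketch of the documented behaviour (assuming a `bigfile.root` without a stored `__index__*` branch, as in the README snippet above; the chunk boundaries are illustrative):

```python
import pandas as pd
from root_pandas import read_root

# Each chunk picks up its default index where the previous one stopped,
# so concatenating the chunks reproduces one continuous 0..n-1 index.
chunks = []
for df in read_root('bigfile.root', chunksize=100000):
    print(df.index[0], df.index[-1])  # 0 99999, then 100000 199999, ...
    chunks.append(df)

combined = pd.concat(chunks)
assert (combined.index == pd.RangeIndex(len(combined))).all()
```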

root_pandas/readwrite.py (19 additions, 11 deletions)

```diff
@@ -5,7 +5,7 @@
 
 import numpy as np
 from numpy.lib.recfunctions import append_fields
-from pandas import DataFrame
+from pandas import DataFrame, RangeIndex
 from root_numpy import root2array, list_trees
 from fnmatch import fnmatch
 from root_numpy import list_branches
@@ -199,11 +199,13 @@ def do_flatten(arr, flatten):
         # XXX could explicitly clean up the opened TFiles with TChain::Reset
 
         def genchunks():
+            current_index = 0
             for chunk in range(int(ceil(float(n_entries) / chunksize))):
                 arr = root2array(paths, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
                 if flatten:
                     arr = do_flatten(arr, flatten)
-                yield convert_to_dataframe(arr)
+                yield convert_to_dataframe(arr, start_index=current_index)
+                current_index += len(arr)
         return genchunks()
 
     arr = root2array(paths, key, all_vars, selection=where, *args, **kwargs)
```
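The core of the fix is a running offset: `genchunks` remembers how many rows it has already yielded and hands that offset to `convert_to_dataframe`. A ROOT-free sketch of the same pattern (the `records` array and `chunksize` here are invented for illustration):

```python
import numpy as np
from pandas import DataFrame, RangeIndex

records = np.zeros(10, dtype=[('x', 'f8')])  # stand-in for root2array output
chunksize = 4

def genchunks():
    current_index = 0
    for start in range(0, len(records), chunksize):
        arr = records[start:start + chunksize]
        # Label this chunk so it continues where the previous one stopped
        index = RangeIndex(start=current_index, stop=current_index + len(arr))
        yield DataFrame.from_records(arr, index=index)
        current_index += len(arr)

for df in genchunks():
    print(df.index)  # RangeIndex 0..4, then 4..8, then 8..10
```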
```diff
@@ -212,15 +214,17 @@ def genchunks():
     return convert_to_dataframe(arr)
 
 
-
-def convert_to_dataframe(array):
+def convert_to_dataframe(array, start_index=None):
     nonscalar_columns = get_nonscalar_columns(array)
     if nonscalar_columns:
         warnings.warn("Ignored the following non-scalar branches: {bad_names}"
                       .format(bad_names=", ".join(nonscalar_columns)), UserWarning)
     indices = list(filter(lambda x: x.startswith('__index__') and x not in nonscalar_columns, array.dtype.names))
     if len(indices) == 0:
-        df = DataFrame.from_records(array, exclude=nonscalar_columns)
+        index = None
+        if start_index is not None:
+            index = RangeIndex(start=start_index, stop=start_index + len(array))
+        df = DataFrame.from_records(array, exclude=nonscalar_columns, index=index)
     elif len(indices) == 1:
         # We store the index under the __index__* branch, where
         # * is the name of the index
```
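When no `__index__*` branch is present, the offset is turned into a `RangeIndex` and passed straight to `DataFrame.from_records`; with `start_index=None` (the non-chunked path) nothing changes. A standalone illustration with a made-up structured array:

```python
import numpy as np
from pandas import DataFrame, RangeIndex

arr = np.array([(1.0,), (2.0,), (3.0,)], dtype=[('x', 'f8')])

# start_index=None behaviour: plain default index
print(DataFrame.from_records(arr, index=None).index.tolist())  # [0, 1, 2]

# chunked behaviour: labels continue from a previous chunk that ended at 4
print(DataFrame.from_records(arr, index=RangeIndex(start=5, stop=8)).index.tolist())  # [5, 6, 7]
```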
```diff
@@ -235,7 +239,7 @@ def convert_to_dataframe(array):
     return df
 
 
-def to_root(df, path, key='default', mode='w', *args, **kwargs):
+def to_root(df, path, key='default', mode='w', store_index=True, *args, **kwargs):
     """
     Write DataFrame to a ROOT file.
 
@@ -247,6 +251,9 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
        Name of tree that the DataFrame will be saved as
     mode: string, {'w', 'a'}
        Mode that the file should be opened in (default: 'w')
+    store_index: bool (optional, default: True)
+       Whether the index of the DataFrame should be stored as
+       an __index__* branch in the tree
 
     Notes
     -----
```
```diff
@@ -270,11 +277,12 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
     from root_numpy import array2root
     # We don't want to modify the user's DataFrame here, so we make a shallow copy
     df_ = df.copy(deep=False)
-    name = df_.index.name
-    if name is None:
-        # Handle the case where the index has no name
-        name = ''
-    df_['__index__' + name] = df_.index
+    if store_index:
+        name = df_.index.name
+        if name is None:
+            # Handle the case where the index has no name
+            name = ''
+        df_['__index__' + name] = df_.index
     arr = df_.to_records(index=False)
     array2root(arr, path, key, mode=mode, *args, **kwargs)
 
```
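A quick usage sketch of the new `store_index` flag (mirroring the test added below; `no_index.root` is just an example filename):

```python
import pandas as pd
from root_pandas import read_root

df = pd.DataFrame({'x': [1, 2, 3]})

# store_index=False skips the __index__* branch entirely; on reading,
# a fresh default index is generated instead of a stored one.
df.to_root('no_index.root', store_index=False)

df_back = read_root('no_index.root')
print(df_back.index.tolist())  # [0, 1, 2]
```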

tests/test.py (17 additions, 0 deletions)

```diff
@@ -82,6 +82,23 @@ def test_chunked_reading():
     assert count == 3
     os.remove('tmp.root')
 
+# Make sure that the default index counts up properly,
+# even if the input is chunked
+def test_chunked_reading_consistent_index():
+    df = pd.DataFrame({'x': [1,2,3,4,5,6]})
+    df.to_root('tmp.root', store_index=False)
+
+    dfs = []
+    for df_ in read_root('tmp.root', chunksize=2):
+        dfs.append(df_)
+        assert(not df_.empty)
+    df_reconstructed = pd.concat(dfs)
+
+    assert_frame_equal(df, df_reconstructed)
+
+    os.remove('tmp.root')
+
+
 def test_multiple_files():
     df = pd.DataFrame({'x': [1,2,3,4,5,6]})
     df.to_root('tmp1.root')
```
