Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 6abd81d

Browse files
alexpearceibab
authored andcommitted
Add support for reading in multiple files at once
This fixes #10
1 parent 45ab34d commit 6abd81d

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

root_pandas/readwrite.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ def get_matching_variables(branches, patterns, fail=True):
6161
return selected
6262

6363

64-
def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=None, flatten=False, *args, **kwargs):
64+
def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=None, flatten=False, *args, **kwargs):
6565
"""
66-
Read a ROOT file into a pandas DataFrame.
66+
Read a ROOT file, or list of ROOT files, into a pandas DataFrame.
6767
Further *args and *kwargs are passed to root_numpy's root2array.
6868
If the root file contains a branch matching __index__*, it will become the DataFrame's index.
6969
7070
Parameters
7171
----------
72-
path: string
73-
The path to the root file.
72+
paths: string or list
73+
The path(s) to the root file(s)
7474
key: string
7575
The key of the tree to load.
7676
columns: str or sequence of str
@@ -98,16 +98,22 @@ def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=N
9898
>>> df = read_root('test.root', 'MyTree', columns=['A{B,C}*', 'D'], where='ABB > 100')
9999
100100
"""
101+
102+
if not isinstance(paths, list):
103+
paths = [paths]
104+
# Use a single file to search for trees and branches
105+
seed_path = paths[0]
106+
101107
if not key:
102-
trees = list_trees(path)
108+
trees = list_trees(seed_path)
103109
if len(trees) == 1:
104110
key = trees[0]
105111
elif len(trees) == 0:
106-
raise ValueError('No trees found in {}'.format(path))
112+
raise ValueError('No trees found in {}'.format(seed_path))
107113
else:
108-
raise ValueError('More than one tree found in {}'.format(path))
114+
raise ValueError('More than one tree found in {}'.format(seed_path))
109115

110-
branches = list_branches(path, key)
116+
branches = list_branches(seed_path, key)
111117

112118
if not columns:
113119
all_vars = branches
@@ -139,20 +145,22 @@ def do_flatten(arr):
139145
return arr
140146

141147
if chunksize:
142-
f = ROOT.TFile.Open(path)
143-
n_entries = f.Get(key).GetEntries()
144-
f.Close()
148+
tchain = ROOT.TChain(key)
149+
for path in paths:
150+
tchain.Add(path)
151+
n_entries = tchain.GetEntries()
152+
# XXX could explicitly clean up the opened TFiles with TChain::Reset
145153

146154
def genchunks():
147155
for chunk in range(int(ceil(float(n_entries) / chunksize))):
148-
arr = root2array(path, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
156+
arr = root2array(paths, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
149157
if flatten:
150158
arr = do_flatten(arr)
151159
yield convert_to_dataframe(arr)
152160

153161
return genchunks()
154162

155-
arr = root2array(path, key, all_vars, selection=where, *args, **kwargs)
163+
arr = root2array(paths, key, all_vars, selection=where, *args, **kwargs)
156164
if flatten:
157165
arr = do_flatten(arr)
158166
return convert_to_dataframe(arr)

tests/test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,29 @@ def test_chunked_reading():
4141
assert count == 3
4242
os.remove('tmp.root')
4343

44+
def test_multiple_files():
45+
df = pd.DataFrame({'x': [1,2,3,4,5,6]})
46+
df.to_root('tmp1.root')
47+
df.to_root('tmp2.root')
48+
df.to_root('tmp3.root')
49+
50+
df_ = read_root(['tmp1.root', 'tmp2.root', 'tmp3.root'])
51+
52+
assert(len(df_) == 3 * len(df))
53+
54+
# Also test chunked read of multiple files
55+
56+
counter = 0
57+
for df_ in read_root(['tmp1.root', 'tmp2.root', 'tmp3.root'], chunksize=3):
58+
assert(len(df_) == 3)
59+
counter += 1
60+
assert(counter == 6)
61+
62+
os.remove('tmp1.root')
63+
os.remove('tmp2.root')
64+
os.remove('tmp3.root')
65+
66+
4467
def test_flatten():
4568
tf = ROOT.TFile('tmp.root', 'RECREATE')
4669
tt = ROOT.TTree("a", "a")

0 commit comments

Comments
 (0)