Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 3b41ae9

Browse files
committed
Implement flattening arrays when reading
This implementation should be enough to fix #7 for now. We can make this more flexible later on, for example allowing the user to flatten different arrays at once, producing all combinations of entries.
1 parent 77a77a1 commit 3b41ae9

File tree

3 files changed

+172
-10
lines changed

3 files changed

+172
-10
lines changed

root_pandas/__init__.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
A module that extends pandas to support the ROOT data format.
44
"""
55

6+
import numpy as np
7+
from numpy.lib.recfunctions import append_fields
68
from pandas import DataFrame
79
from root_numpy import root2array, list_trees
810
from fnmatch import fnmatch
@@ -13,6 +15,8 @@
1315
import re
1416
import ROOT
1517

18+
from .utils import stretch
19+
1620

1721
__all__ = ['read_root']
1822

@@ -57,7 +61,7 @@ def get_matching_variables(branches, patterns, fail=True):
5761
return selected
5862

5963

60-
def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=None, *args, **kwargs):
64+
def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=None, flatten=False, *args, **kwargs):
6165
"""
6266
Read a ROOT file into a pandas DataFrame.
6367
Further *args and **kwargs are passed to root_numpy's root2array.
@@ -66,17 +70,23 @@ def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=N
6670
Parameters
6771
----------
6872
path: string
69-
The path to the root file
73+
The path to the root file.
7074
key: string
7175
The key of the tree to load.
7276
columns: str or sequence of str
7377
A sequence of shell-patterns (can contain *, ?, [] or {}). Matching columns are read.
7478
ignore: str or sequence of str
75-
A sequence of shell-patterns (can contain *, ?, [] or {}). All matching columns are ignored (overriding the columns argument)
79+
A sequence of shell-patterns (can contain *, ?, [] or {}). All matching columns are ignored (overriding the columns argument).
7680
chunksize: int
77-
If this parameter is specified, an iterator is returned that yields DataFrames with `chunksize` rows
81+
If this parameter is specified, an iterator is returned that yields DataFrames with `chunksize` rows.
7882
where: str
79-
Only rows that match the expression will be read
83+
Only rows that match the expression will be read.
84+
flatten: bool
85+
If set to True, will use root_numpy.stretch to flatten arrays in the root file into individual entries.
86+
All arrays specified in the columns must have the same length for this to work.
87+
Be careful if you combine this with chunksize, as chunksize will refer to the number of unflattened entries,
88+
so you will be iterating over a number of entries that is potentially larger than chunksize.
89+
The index of each element within its former array will be saved in the __array_index column.
8090
8191
Returns
8292
-------
@@ -89,10 +99,10 @@ def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=N
8999
90100
"""
91101
if not key:
92-
branches = list_trees(path)
93-
if len(branches) == 1:
94-
key = branches[0]
95-
elif len(branches) == 0:
102+
trees = list_trees(path)
103+
if len(trees) == 1:
104+
key = trees[0]
105+
elif len(trees) == 0:
96106
raise ValueError('No trees found in {}'.format(path))
97107
else:
98108
raise ValueError('More than one tree found in {}'.format(path))
@@ -123,18 +133,28 @@ def read_root(path, key=None, columns=None, ignore=None, chunksize=None, where=N
123133
for var in ignored:
124134
all_vars.remove(var)
125135

136+
def do_flatten(arr):
    """Flatten the array-valued columns of ``arr`` into individual rows.

    The structured array is passed through ``stretch`` with
    ``return_indices=True``; the returned per-element indices are then
    attached to the stretched array as the ``__array_index`` column.
    """
    stretched, positions = stretch(arr, return_indices=True)
    return append_fields(
        stretched,
        '__array_index',
        positions,
        usemask=False,
        asrecarray=True,
    )
140+
126141
if chunksize:
127-
f = ROOT.TFile(path)
142+
f = ROOT.TFile.Open(path)
128143
n_entries = f.Get(key).GetEntries()
129144
f.Close()
130145

131146
def genchunks():
132147
for chunk in range(int(ceil(float(n_entries) / chunksize))):
133148
arr = root2array(path, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
149+
if flatten:
150+
arr = do_flatten(arr)
134151
yield convert_to_dataframe(arr)
152+
135153
return genchunks()
136154

137155
arr = root2array(path, key, all_vars, selection=where, *args, **kwargs)
156+
if flatten:
157+
arr = do_flatten(arr)
138158
return convert_to_dataframe(arr)
139159

140160

root_pandas/utils.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (c) 2012 rootpy developers and contributors
2+
#
3+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
4+
# this software and associated documentation files (the "Software"), to deal in
5+
# the Software without restriction, including without limitation the rights to
6+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7+
# the Software, and to permit persons to whom the Software is furnished to do so,
8+
# subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all
11+
# copies or substantial portions of the Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17+
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19+
#
20+
#
21+
# Code temporarily copied from the root_numpy package
22+
#
23+
24+
import numpy as np
25+
VLEN = np.vectorize(len)
26+
27+
def stretch(arr, fields=None, return_indices=False):
    """Stretch an array.

    Stretch an array by ``hstack()``-ing multiple array fields while
    preserving column names and record array structure. If a scalar field is
    specified, it will be stretched along with array fields.

    Parameters
    ----------
    arr : NumPy structured or record array
        The array to be stretched.
    fields : list of strings, optional (default=None)
        A list of column names to stretch. If None, then stretch all fields.
    return_indices : bool, optional (default=False)
        If True, the array index of each stretched array entry will be
        returned in addition to the stretched array.
        This changes the return type of this function to a tuple consisting
        of a structured array and a NumPy integer array.

    Returns
    -------
    ret : A NumPy structured array
        The stretched array. If ``arr`` has no rows, an empty array is
        returned; variable-length columns then keep the object dtype because
        there is no element to take the item dtype from.
    idx : A NumPy integer array
        Only returned (as the second element of a tuple) when
        ``return_indices`` is True. Position of every output entry within
        the array it came from.

    Raises
    ------
    ValueError
        If the selected array columns do not all have the same per-row
        lengths.
    RuntimeError
        If none of the selected columns is an array column.

    Examples
    --------
    >>> import numpy as np
    >>> arr = np.empty(2, dtype=[('scalar', int), ('array', 'O')])
    >>> arr[0] = (0, np.array([1, 2, 3], dtype=float))
    >>> arr[1] = (1, np.array([4, 5, 6], dtype=float))
    >>> stretch(arr, ['scalar', 'array'])
    array([(0, 1.), (0, 2.), (0, 3.), (1, 4.), (1, 5.), (1, 6.)],
          dtype=[('scalar', '<i8'), ('array', '<f8')])

    """
    object_dtype = np.dtype(object)
    n_rows = arr.shape[0]
    dtype = []
    len_array = None

    if fields is None:
        fields = arr.dtype.names

    # Construct dtype and check consistency
    for field in fields:
        dt = arr.dtype[field]
        is_vlen = dt == object_dtype

        if not is_vlen and not dt.shape:
            # Scalar field
            dtype.append((field, dt))
            continue

        if is_vlen:
            # Variable-length array field. np.fromiter (unlike np.vectorize)
            # also accepts an input without any rows.
            lengths = np.fromiter(
                (len(item) for item in arr[field]),
                dtype=np.intp, count=n_rows)
        else:
            # Fixed-length array field
            lengths = np.repeat(dt.shape[0], n_rows)

        if len_array is None:
            len_array = lengths
        elif not np.array_equal(lengths, len_array):
            raise ValueError(
                "inconsistent lengths of array columns in input")

        if is_vlen:
            # The item dtype is taken from the first row; without rows there
            # is nothing to inspect, so the object dtype is kept.
            item_dtype = arr[field][0].dtype if n_rows else dt
            dtype.append((field, item_dtype))
        else:
            # Stretching consumes the leading axis of a fixed-length field.
            dtype.append((field, dt.base, dt.shape[1:]))

    if len_array is None:
        raise RuntimeError("no array column in input")

    # Build stretched output
    total = int(np.sum(len_array))
    ret = np.empty(total, dtype=dtype)

    # hstack/vstack reject an empty sequence, so only fill when rows exist.
    if n_rows:
        for field in fields:
            dt = arr.dtype[field]
            if dt == object_dtype or len(dt.shape) == 1:
                # Variable-length or 1D fixed-length array field
                ret[field] = np.hstack(arr[field])
            elif len(dt.shape):
                # Multidimensional fixed-length array field
                ret[field] = np.vstack(arr[field])
            else:
                # Scalar field
                ret[field] = np.repeat(arr[field], len_array)

    if return_indices:
        # Index within the source array = running output position minus the
        # output position at which the source row starts.
        row_starts = np.cumsum(len_array) - len_array
        idx = np.arange(total) - np.repeat(row_starts, len_array)
        return ret, idx

    return ret

tests/test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from root_pandas import read_root
33
from root_numpy import list_branches
44
from pandas.util.testing import assert_frame_equal
5+
import numpy as np
6+
import ROOT
57
import os
68

79
def test_read_write():
@@ -39,3 +41,35 @@ def test_chunked_reading():
3941
assert count == 3
4042
os.remove('tmp.root')
4143

44+
def test_flatten():
    """Check that ``read_root(..., flatten=True)`` expands array branches.

    Writes a tree with two entries, each holding a three-element array in
    branch ``x``, and verifies that reading with ``flatten=True`` yields one
    row per array element together with the ``__array_index`` column, both
    for a plain read and for a chunked read.
    """
    tf = ROOT.TFile('tmp.root', 'RECREATE')
    tt = ROOT.TTree("a", "a")

    # The branch is declared as '/I' (32-bit integer), so the buffer handed
    # to ROOT must be int32 rather than NumPy's platform default integer.
    length = np.array([3], dtype='int32')
    x = np.array([0, 1, 2], dtype='float64')
    tt.Branch('length', length, 'length/I')
    tt.Branch('x', x, 'x[length]/D')

    tt.Fill()
    x[0] = 3
    x[1] = 4
    x[2] = 5
    tt.Fill()

    tf.Write()
    tf.Close()

    try:
        df_ = read_root('tmp.root', flatten=True)

        assert('__array_index' in df_.columns)
        assert(len(df_) == 6)
        assert(np.all(df_['__array_index'] == np.array([0, 1, 2, 0, 1, 2])))

        # Also flatten chunked data
        n_chunks = 0
        for df_ in read_root('tmp.root', flatten=True, chunksize=1):
            assert(len(df_) == 3)
            assert(np.all(df_['__array_index'] == np.array([0, 1, 2])))
            n_chunks += 1
        # Two entries read one at a time must give exactly two chunks;
        # without this check an empty iterator would pass silently.
        assert(n_chunks == 2)
    finally:
        # Remove the temporary file even when an assertion fails, as the
        # other tests in this module do.
        os.remove('tmp.root')
75+

0 commit comments

Comments
 (0)