Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit d7e2956

Browse files
KonstantinSchubert authored and ibab committed
Require user to specify which columns to flatten
This fixes #20.
1 parent 28f2097 commit d7e2956

File tree

2 files changed

+74
-28
lines changed

2 files changed

+74
-28
lines changed

root_pandas/readwrite.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def expand_braces(orig):
5050
return list(set(res))
5151

5252

53+
def get_nonscalar_columns(array):
54+
first_row = array[0]
55+
bad_cols = np.array([x.ndim != 0 for x in first_row])
56+
col_names = np.array(array.dtype.names)
57+
bad_names = col_names[bad_cols]
58+
return list(bad_names)
59+
5360
def get_matching_variables(branches, patterns, fail=True):
5461
selected = []
5562

@@ -85,9 +92,9 @@ def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=
8592
If this parameter is specified, an iterator is returned that yields DataFrames with `chunksize` rows.
8693
where: str
8794
Only rows that match the expression will be read.
88-
flatten: bool
89-
If set to True, will use root_numpy.stretch to flatten arrays in the root file into individual entries.
90-
All arrays specified in the columns must have the same length for this to work.
95+
flatten: sequence of str
96+
A sequence of column names. Will use root_numpy.stretch to flatten arrays in the specified columns into
97+
individual entries. All arrays specified in the columns must have the same length for this to work.
9198
Be careful if you combine this with chunksize, as chunksize will refer to the number of unflattened entries,
9299
so you will be iterating over a number of entries that is potentially larger than chunksize.
93100
The index of each element within its former array will be saved in the __array_index column.
@@ -143,8 +150,19 @@ def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=
143150
for var in ignored:
144151
all_vars.remove(var)
145152

146-
def do_flatten(arr):
147-
arr_, idx = stretch(arr, return_indices=True)
153+
def do_flatten(arr, flatten):
154+
if flatten is True:
155+
warnings.warn(" The option flatten=True is deprecated. Please specify the branches you would like "
156+
"to flatten in a list: flatten=['foo', 'bar']", FutureWarning)
157+
arr_, idx = stretch(arr, return_indices=True)
158+
else:
159+
nonscalar = get_nonscalar_columns(arr)
160+
fields = [x for x in arr.dtype.names if (x not in nonscalar or x in flatten)]
161+
will_drop = [x for x in arr.dtype.names if x not in fields]
162+
if will_drop:
163+
warnings.warn("Ignored the following non-scalar branches: {bad_names}"
164+
.format(bad_names=", ".join(will_drop)), UserWarning)
165+
arr_, idx = stretch(arr, fields=fields, return_indices=True)
148166
arr = append_fields(arr_, '__array_index', idx, usemask=False, asrecarray=True)
149167
return arr
150168

@@ -159,31 +177,22 @@ def genchunks():
159177
for chunk in range(int(ceil(float(n_entries) / chunksize))):
160178
arr = root2array(paths, key, all_vars, start=chunk * chunksize, stop=(chunk+1) * chunksize, selection=where, *args, **kwargs)
161179
if flatten:
162-
arr = do_flatten(arr)
180+
arr = do_flatten(arr, flatten)
163181
yield convert_to_dataframe(arr)
164-
165182
return genchunks()
166183

167184
arr = root2array(paths, key, all_vars, selection=where, *args, **kwargs)
168185
if flatten:
169-
arr = do_flatten(arr)
186+
arr = do_flatten(arr, flatten)
170187
return convert_to_dataframe(arr)
171188

172189

173190

174191
def convert_to_dataframe(array):
175-
176-
def get_nonscalar_columns(array):
177-
first_row = array[0]
178-
bad_cols = np.array([x.ndim != 0 for x in first_row])
179-
col_names = np.array(array.dtype.names)
180-
bad_names = col_names[bad_cols]
181-
if not bad_names.size == 0:
182-
warnings.warn("Ignored the following non-scalar branches: {bad_names}"
183-
.format(bad_names=", ".join(bad_names)), UserWarning)
184-
return list(bad_names)
185-
186192
nonscalar_columns = get_nonscalar_columns(array)
193+
if nonscalar_columns:
194+
warnings.warn("Ignored the following non-scalar branches: {bad_names}"
195+
.format(bad_names=", ".join(nonscalar_columns)), UserWarning)
187196
indices = list(filter(lambda x: x.startswith('__index__') and x not in nonscalar_columns, array.dtype.names))
188197
if len(indices) == 0:
189198
df = DataFrame.from_records(array, exclude=nonscalar_columns)

tests/test.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import ROOT
88
import os
9+
import warnings
910

1011
def test_read_write():
1112
df = pd.DataFrame({'x': [1,2,3]})
@@ -110,34 +111,69 @@ def test_flatten():
110111

111112
length = np.array([3])
112113
x = np.array([0, 1, 2], dtype='float64')
114+
y = np.array([6, 7, 8], dtype='float64')
113115
tt.Branch('length', length, 'length/I')
114116
tt.Branch('x', x, 'x[length]/D')
115-
117+
tt.Branch('y', y, 'y[length]/D')
116118
tt.Fill()
117119
x[0] = 3
118120
x[1] = 4
119121
x[2] = 5
122+
y[0] = 9
123+
y[1] = 10
124+
y[2] = 11
120125
tt.Fill()
121126

122127
tf.Write()
123128
tf.Close()
124129

125130
branches = list_branches('tmp.root')
126131

127-
df_ = read_root('tmp.root', flatten=True)
128132

133+
# flatten one out of two array branches
134+
with warnings.catch_warnings():
135+
warnings.simplefilter("ignore")
136+
df_ = read_root('tmp.root', flatten=['x'])
129137
assert('__array_index' in df_.columns)
130138
assert(len(df_) == 6)
139+
assert('length' in df_.columns.values)
140+
assert('x' in df_.columns.values)
141+
assert('y' not in df_.columns.values)
131142
assert(np.all(df_['__array_index'] == np.array([0, 1, 2, 0, 1, 2])))
143+
assert(np.all(df_['x'] == np.array([0, 1, 2, 3, 4, 5])))
132144

133-
# Also flatten chunked data
134145

135-
for df_ in read_root('tmp.root', flatten=True, chunksize=1):
146+
# flatten both array branches
147+
df_ = read_root('tmp.root', flatten=['x','y'])
148+
assert('__array_index' in df_.columns)
149+
assert(len(df_) == 6)
150+
assert(np.all(df_['__array_index'] == np.array([0, 1, 2, 0, 1, 2])))
151+
assert('length' in df_.columns.values)
152+
assert('x' in df_.columns.values)
153+
assert('y' in df_.columns.values)
154+
assert(np.all(df_['x'] == np.array([0, 1, 2, 3, 4, 5])))
155+
assert(np.all(df_['y'] == np.array([6, 7, 8, 9, 10, 11])))
156+
157+
158+
# Also flatten chunked data
159+
for df_ in read_root('tmp.root', flatten=['x'], chunksize=1):
136160
assert(len(df_) == 3)
137161
assert(np.all(df_['__array_index'] == np.array([0, 1, 2])))
138162

163+
# Also test deprecated behaviour
164+
with warnings.catch_warnings():
165+
warnings.simplefilter("ignore")
166+
df_ = read_root('tmp.root', flatten=True)
167+
assert('__array_index' in df_.columns)
168+
assert(len(df_) == 6)
169+
assert(np.all(df_['__array_index'] == np.array([0, 1, 2, 0, 1, 2])))
170+
171+
139172
os.remove('tmp.root')
140173

174+
175+
176+
141177
def test_drop_nonscalar_columns():
142178
array = np.array([1, 2, 3])
143179
matrix = np.array([[1, 2, 3], [4, 5, 6]])
@@ -157,11 +193,12 @@ def test_drop_nonscalar_columns():
157193

158194
path = 'tmp.root'
159195
array2root(arr, path, 'ntuple', mode='recreate')
160-
161-
df = read_root(path, flatten=False)
162-
# the above line throws an error if flatten=True because nonscalar columns
163-
# are dropped only after the flattening is applied. However, the flattening
164-
# algorithm can not deal with arrays of more than one dimension.
196+
with warnings.catch_warnings():
197+
warnings.simplefilter("ignore")
198+
df = read_root(path, flatten=False)
199+
# the above line throws an error if flatten=True because nonscalar columns
200+
# are dropped only after the flattening is applied. However, the flattening
201+
# algorithm can not deal with arrays of more than one dimension.
165202
assert(len(df.columns) == 2)
166203
assert(np.all(df.index.values == np.array([0, 1])))
167204
assert(np.all(df.a.values == np.array([3, 2])))

0 commit comments

Comments
 (0)