Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 28f2097

Browse files
KonstantinSchubert authored and ibab committed
Exclude columns that are not one-dimensional (#19)
Exclude non-scalar columns and print a warning
1 parent 7ab1a9c commit 28f2097

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

root_pandas/readwrite.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from math import ceil
1515
import re
1616
import ROOT
17+
import warnings
1718

1819
from .utils import stretch
1920

@@ -169,14 +170,27 @@ def genchunks():
169170
return convert_to_dataframe(arr)
170171

171172

173+
172174
def convert_to_dataframe(array):
173-
indices = list(filter(lambda x: x.startswith('__index__'), array.dtype.names))
175+
176+
def get_nonscalar_columns(array):
    """Return the names of the columns of ``array`` that do not hold scalars.

    Only the first record is inspected: a column counts as non-scalar when
    its value in that record has one or more dimensions. A ``UserWarning``
    listing the offending names is emitted when any are found.

    Parameters
    ----------
    array : numpy structured (record) array

    Returns
    -------
    list of str
        Names of the non-scalar columns, in dtype order. Empty when the
        array has no records, since there is then no row to inspect.
    """
    # Indexing the first record of an empty array would raise IndexError.
    if len(array) == 0:
        return []

    first_row = array[0]
    # np.ndim also accepts plain Python objects (which have no ``.ndim``
    # attribute) and reports them as zero-dimensional.
    bad_names = [
        name
        for name, value in zip(array.dtype.names, first_row)
        if np.ndim(value) != 0
    ]
    if bad_names:
        warnings.warn("Ignored the following non-scalar branches: {bad_names}"
                      .format(bad_names=", ".join(bad_names)), UserWarning)
    return bad_names
185+
186+
nonscalar_columns = get_nonscalar_columns(array)
187+
indices = list(filter(lambda x: x.startswith('__index__') and x not in nonscalar_columns, array.dtype.names))
174188
if len(indices) == 0:
175-
df = DataFrame.from_records(array)
189+
df = DataFrame.from_records(array, exclude=nonscalar_columns)
176190
elif len(indices) == 1:
177191
# We store the index under the __index__* branch, where
178192
# * is the name of the index
179-
df = DataFrame.from_records(array, index=indices[0])
193+
df = DataFrame.from_records(array, index=indices[0], exclude=nonscalar_columns)
180194
index_name = indices[0][len('__index__'):]
181195
if not index_name:
182196
# None means the index has no name

tests/test.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import pandas as pd
22
from root_pandas import read_root
33
from root_numpy import list_branches
4+
from root_numpy import array2root
45
from pandas.util.testing import assert_frame_equal
56
import numpy as np
67
import ROOT
@@ -137,3 +138,33 @@ def test_flatten():
137138

138139
os.remove('tmp.root')
139140

141+
def test_drop_nonscalar_columns():
    """Check that read_root keeps scalar branches and drops non-scalar ones.

    A tree with two scalar branches ('a', 'd') and three fixed-size array
    branches ('b', 'c', 'e') is written to disk and read back; only the
    scalar branches are expected to appear as DataFrame columns.
    """
    array = np.array([1, 2, 3])
    matrix = np.array([[1, 2, 3], [4, 5, 6]])
    bool_matrix = np.array([[True, False, True], [True, True, True]])

    dt = np.dtype([
        ('a', 'i4'),
        ('b', 'int64', array.shape),
        ('c', 'int64', matrix.shape),
        ('d', 'bool_'),
        ('e', 'bool_', matrix.shape),
    ])
    arr = np.array([
        (3, array, matrix, True, bool_matrix),
        (2, array, matrix, False, bool_matrix)],
        dtype=dt)

    path = 'tmp.root'
    try:
        array2root(arr, path, 'ntuple', mode='recreate')

        # flatten=True would raise here: non-scalar columns are dropped only
        # after flattening is applied, and the flattening algorithm cannot
        # deal with arrays of more than one dimension.
        df = read_root(path, flatten=False)

        assert len(df.columns) == 2
        assert np.all(df.index.values == np.array([0, 1]))
        assert np.all(df.a.values == np.array([3, 2]))
        assert np.all(df.d.values == np.array([True, False]))
    finally:
        # Always clean up, so a failing assertion does not leave the
        # temporary file behind for later tests.
        if os.path.exists(path):
            os.remove(path)

0 commit comments

Comments
 (0)