Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Remove depricated use of pandas and ensure column order is correct #83

Merged
merged 5 commits into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions root_pandas/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
A module that extends pandas to support the ROOT data format.
"""
from collections import Counter

import numpy as np
from numpy.lib.recfunctions import append_fields
Expand Down Expand Up @@ -95,11 +96,11 @@ def get_nonscalar_columns(array):
def get_matching_variables(branches, patterns, fail=True):
# Convert branches to a set to make x "in branches" O(1) on average
branches = set(branches)
patterns = set(patterns)
# Find any trivial matches
selected = list(branches.intersection(patterns))
selected = sorted(branches.intersection(patterns),
key=lambda s: patterns.index(s))
# Any matches that weren't trivial need to be looped over...
for pattern in patterns.difference(selected):
for pattern in set(patterns).difference(selected):
found = False
# Avoid using fnmatch if the pattern if possible
if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern):
Expand Down Expand Up @@ -317,7 +318,7 @@ def convert_to_dataframe(array, start_index=None):
# Filter to remove __index__ columns
columns = [c for c in array.dtype.names if c in df.columns]
assert len(columns) == len(df.columns), (columns, df.columns)
df = df.reindex_axis(columns, axis=1, copy=False)
df = df.reindex(columns, axis=1, copy=False)

# Convert categorical columns back to categories
for c in df.columns:
Expand Down Expand Up @@ -366,6 +367,11 @@ def to_root(df, path, key='my_ttree', mode='w', store_index=True, *args, **kwarg
else:
raise ValueError('Unknown mode: {}. Must be "a" or "w".'.format(mode))

column_name_counts = Counter(df.columns)
if max(column_name_counts.values()) > 1:
raise ValueError('DataFrame contains duplicated column names: ' +
' '.join({k for k, v in column_name_counts.items() if v > 1}))

from root_numpy import array2tree
# We don't want to modify the user's DataFrame here, so we make a shallow copy
df_ = df.copy(deep=False)
Expand Down
2 changes: 1 addition & 1 deletion root_pandas/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
'version_info',
]

__version__ = '0.6.1'
__version__ = '0.7.0'
version = __version__
version_info = tuple(__version__.split('.'))
17 changes: 17 additions & 0 deletions tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,20 @@ def test_issue_63():
assert all(len(df) == 1 for df in result)
os.remove('tmp_1.root')
os.remove('tmp_2.root')


def test_issue_80():
df = pd.DataFrame({'a': [1, 2], 'b': [4, 5]})
df.columns = ['a', 'a']
try:
root_pandas.to_root(df, '/tmp/example.root')
except ValueError as e:
assert 'DataFrame contains duplicated column names' in e.args[0]
else:
raise Exception('ValueError is expected')


def test_issue_82():
variables = ['MET_px', 'MET_py', 'EventWeight']
df = root_pandas.read_root('http://scikit-hep.org/uproot/examples/HZZ.root', 'events', columns=variables)
assert list(df.columns) == variables