5
5
6
6
import numpy as np
7
7
from numpy .lib .recfunctions import append_fields
8
- from pandas import DataFrame
8
+ from pandas import DataFrame , RangeIndex
9
9
from root_numpy import root2array , list_trees
10
10
from fnmatch import fnmatch
11
11
from root_numpy import list_branches
@@ -199,11 +199,13 @@ def do_flatten(arr, flatten):
199
199
# XXX could explicitly clean up the opened TFiles with TChain::Reset
200
200
201
201
def genchunks ():
202
+ current_index = 0
202
203
for chunk in range (int (ceil (float (n_entries ) / chunksize ))):
203
204
arr = root2array (paths , key , all_vars , start = chunk * chunksize , stop = (chunk + 1 ) * chunksize , selection = where , * args , ** kwargs )
204
205
if flatten :
205
206
arr = do_flatten (arr , flatten )
206
- yield convert_to_dataframe (arr )
207
+ yield convert_to_dataframe (arr , start_index = current_index )
208
+ current_index += len (arr )
207
209
return genchunks ()
208
210
209
211
arr = root2array (paths , key , all_vars , selection = where , * args , ** kwargs )
@@ -212,15 +214,17 @@ def genchunks():
212
214
return convert_to_dataframe (arr )
213
215
214
216
215
-
216
- def convert_to_dataframe (array ):
217
+ def convert_to_dataframe (array , start_index = None ):
217
218
nonscalar_columns = get_nonscalar_columns (array )
218
219
if nonscalar_columns :
219
220
warnings .warn ("Ignored the following non-scalar branches: {bad_names}"
220
221
.format (bad_names = ", " .join (nonscalar_columns )), UserWarning )
221
222
indices = list (filter (lambda x : x .startswith ('__index__' ) and x not in nonscalar_columns , array .dtype .names ))
222
223
if len (indices ) == 0 :
223
- df = DataFrame .from_records (array , exclude = nonscalar_columns )
224
+ index = None
225
+ if start_index is not None :
226
+ index = RangeIndex (start = start_index , stop = start_index + len (array ))
227
+ df = DataFrame .from_records (array , exclude = nonscalar_columns , index = index )
224
228
elif len (indices ) == 1 :
225
229
# We store the index under the __index__* branch, where
226
230
# * is the name of the index
@@ -235,7 +239,7 @@ def convert_to_dataframe(array):
235
239
return df
236
240
237
241
238
- def to_root (df , path , key = 'default' , mode = 'w' , * args , ** kwargs ):
242
+ def to_root (df , path , key = 'default' , mode = 'w' , store_index = True , * args , ** kwargs ):
239
243
"""
240
244
Write DataFrame to a ROOT file.
241
245
@@ -247,6 +251,9 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
247
251
Name of tree that the DataFrame will be saved as
248
252
mode: string, {'w', 'a'}
249
253
Mode that the file should be opened in (default: 'w')
254
+ store_index: bool (optional, default: True)
255
+ Whether the index of the DataFrame should be stored as
256
+ an __index__* branch in the tree
250
257
251
258
Notes
252
259
-----
@@ -270,11 +277,12 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
270
277
from root_numpy import array2root
271
278
# We don't want to modify the user's DataFrame here, so we make a shallow copy
272
279
df_ = df .copy (deep = False )
273
- name = df_ .index .name
274
- if name is None :
275
- # Handle the case where the index has no name
276
- name = ''
277
- df_ ['__index__' + name ] = df_ .index
280
+ if store_index :
281
+ name = df_ .index .name
282
+ if name is None :
283
+ # Handle the case where the index has no name
284
+ name = ''
285
+ df_ ['__index__' + name ] = df_ .index
278
286
arr = df_ .to_records (index = False )
279
287
array2root (arr , path , key , mode = mode , * args , ** kwargs )
280
288
0 commit comments