|
| 1 | +# Copyright (c) 2012 rootpy developers and contributors |
| 2 | +# |
| 3 | +# Permission is hereby granted, free of charge, to any person obtaining a copy of |
| 4 | +# this software and associated documentation files (the "Software"), to deal in |
| 5 | +# the Software without restriction, including without limitation the rights to |
| 6 | +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of |
| 7 | +# the Software, and to permit persons to whom the Software is furnished to do so, |
| 8 | +# subject to the following conditions: |
| 9 | +# |
| 10 | +# The above copyright notice and this permission notice shall be included in all |
| 11 | +# copies or substantial portions of the Software. |
| 12 | +# |
| 13 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 14 | +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
| 15 | +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR |
| 16 | +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, |
| 17 | +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
| 18 | +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
| 19 | +# |
| 20 | +# |
| 21 | +# Code temporarily copied from the root_numpy package |
| 22 | +# |
| 23 | + |
| 24 | +import numpy as np |
| 25 | +VLEN = np.vectorize(len) |
| 26 | + |
| 27 | +def stretch(arr, fields=None, return_indices=False): |
| 28 | + """Stretch an array. |
| 29 | + Stretch an array by ``hstack()``-ing multiple array fields while |
| 30 | + preserving column names and record array structure. If a scalar field is |
| 31 | + specified, it will be stretched along with array fields. |
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | + arr : NumPy structured or record array |
| 35 | + The array to be stretched. |
| 36 | + fields : list of strings, optional (default=None) |
| 37 | + A list of column names to stretch. If None, then stretch all fields. |
| 38 | + return_indices : bool, optional (default=False) |
| 39 | + If True, the array index of each stretched array entry will be |
| 40 | + returned in addition to the stretched array. |
| 41 | + This changes the return type of this function to a tuple consisting |
| 42 | + of a structured array and a numpy int64 array. |
| 43 | + Returns |
| 44 | + ------- |
| 45 | + ret : A NumPy structured array |
| 46 | + The stretched array. |
| 47 | + Examples |
| 48 | + -------- |
| 49 | + >>> import numpy as np |
| 50 | + >>> from root_numpy import stretch |
| 51 | + >>> arr = np.empty(2, dtype=[('scalar', np.int), ('array', 'O')]) |
| 52 | + >>> arr[0] = (0, np.array([1, 2, 3], dtype=np.float)) |
| 53 | + >>> arr[1] = (1, np.array([4, 5, 6], dtype=np.float)) |
| 54 | + >>> stretch(arr, ['scalar', 'array']) |
| 55 | + array([(0, 1.0), (0, 2.0), (0, 3.0), (1, 4.0), (1, 5.0), (1, 6.0)], |
| 56 | + dtype=[('scalar', '<i8'), ('array', '<f8')]) |
| 57 | + """ |
| 58 | + dtype = [] |
| 59 | + len_array = None |
| 60 | + |
| 61 | + if fields is None: |
| 62 | + fields = arr.dtype.names |
| 63 | + |
| 64 | + # Construct dtype and check consistency |
| 65 | + for field in fields: |
| 66 | + dt = arr.dtype[field] |
| 67 | + if dt == 'O' or len(dt.shape): |
| 68 | + if dt == 'O': |
| 69 | + # Variable-length array field |
| 70 | + lengths = VLEN(arr[field]) |
| 71 | + else: |
| 72 | + lengths = np.repeat(dt.shape[0], arr.shape[0]) |
| 73 | + # Fixed-length array field |
| 74 | + if len_array is None: |
| 75 | + len_array = lengths |
| 76 | + elif not np.array_equal(lengths, len_array): |
| 77 | + raise ValueError( |
| 78 | + "inconsistent lengths of array columns in input") |
| 79 | + if dt == 'O': |
| 80 | + dtype.append((field, arr[field][0].dtype)) |
| 81 | + else: |
| 82 | + dtype.append((field, arr[field].dtype, dt.shape[1:])) |
| 83 | + else: |
| 84 | + # Scalar field |
| 85 | + dtype.append((field, dt)) |
| 86 | + |
| 87 | + if len_array is None: |
| 88 | + raise RuntimeError("no array column in input") |
| 89 | + |
| 90 | + # Build stretched output |
| 91 | + ret = np.empty(np.sum(len_array), dtype=dtype) |
| 92 | + for field in fields: |
| 93 | + dt = arr.dtype[field] |
| 94 | + if dt == 'O' or len(dt.shape) == 1: |
| 95 | + # Variable-length or 1D fixed-length array field |
| 96 | + ret[field] = np.hstack(arr[field]) |
| 97 | + elif len(dt.shape): |
| 98 | + # Multidimensional fixed-length array field |
| 99 | + ret[field] = np.vstack(arr[field]) |
| 100 | + else: |
| 101 | + # Scalar field |
| 102 | + ret[field] = np.repeat(arr[field], len_array) |
| 103 | + |
| 104 | + if return_indices: |
| 105 | + idx = np.concatenate(list(map(np.arange, len_array))) |
| 106 | + return ret, idx |
| 107 | + |
| 108 | + return ret |
0 commit comments