
Commit 76f3a7b

aseyboldt, michaelosthege, and AlexAndorra authored
Allow specification of dims instead of shape (#3551)
* Allow specification of dims instead of shape
* Add pm.TidyData
* Create coords for pm.Data(ndarray)
* empty commit to trigger CI
* Apply suggestions from code review (Co-authored-by: Alexandre ANDORRA <[email protected]>)
* apply black formatting
* address review comments & formatting
* Add demonstration of named coordinates/dims
* don't require dim names to be identifiers
* sort imports
* raise ShapeError instead of ValueError
* formatting
* robustify Dtype and ShapeError
* Removed TidyData and refined dims and coords implementation
* Changed name of kwarg export_dims and improved docstrings
* Add link to ArviZ in docstrings
* Removed TidyData from __all__
* Polished Data container NB
* Fixed line break in data.py
* Fix inference of coords for dataframes
* Refined Data container NB
* Updated getting started NB with new dims and coords features
* Reran getting started NB
* Blackified NBs
* rerun with ArviZ branch
* use np.shape to be compatible with tuples/lists
* add tests for named coordinate handling
* Extended tests for data container

Co-authored-by: Michael Osthege <[email protected]>
Co-authored-by: Michael Osthege <[email protected]>
Co-authored-by: Alexandre ANDORRA <[email protected]>
1 parent 8a8beab commit 76f3a7b

File tree

9 files changed (+2375, -728 lines)


RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 - `pm.LKJCholeskyCov` now automatically computes and returns the unpacked Cholesky decomposition, the correlations and the standard deviations of the covariance matrix (see [#3881](https://github.com/pymc-devs/pymc3/pull/3881)).
 - `pm.Data` container can now be used for index variables, i.e with integer data and not only floats (issue [#3813](https://github.com/pymc-devs/pymc3/issues/3813), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
 - `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
+- Allow users to specify coordinates and dimension names instead of numerical shapes when specifying a model. This makes interoperability with ArviZ easier. ([see #3551](https://github.com/pymc-devs/pymc3/pull/3551))
 - Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927)
 - Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931).
 - SMC-ABC: add option to define summary statistics, allow to sample from more complex models, remove redundant distances [#3940](https://github.com/pymc-devs/pymc3/issues/3940)
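As a quick illustration of the entry added above, the new workflow looks roughly like the following sketch. The dimension name "town" and its coordinate values are invented for the example, and it assumes a PyMC3 version that includes #3551:

import pymc3 as pm

# Coordinate values for an illustrative "town" dimension.
coords = {"town": ["york", "bristol", "leeds"]}

with pm.Model(coords=coords) as model:
    # dims="town" takes the place of an explicit shape=3; ArviZ can later
    # label the posterior with the coordinate values instead of bare indices.
    mu = pm.Normal("mu", mu=0.0, sigma=1.0, dims="town")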

docs/source/notebooks/data_container.ipynb

Lines changed: 1064 additions & 207 deletions
Large diffs are not rendered by default.

docs/source/notebooks/getting_started.ipynb

Lines changed: 669 additions & 231 deletions
Large diffs are not rendered by default.

pymc3/data.py

Lines changed: 142 additions & 52 deletions
@@ -12,23 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, List, Any
+import collections
 from copy import copy
 import io
 import os
 import pkgutil
-import collections
+from typing import Dict, List, Any
+
 import numpy as np
+import pandas as pd
 import pymc3 as pm
 import theano.tensor as tt
 import theano

 __all__ = [
-    'get_data',
-    'GeneratorAdapter',
-    'Minibatch',
-    'align_minibatches',
-    'Data',
+    "get_data",
+    "GeneratorAdapter",
+    "Minibatch",
+    "align_minibatches",
+    "Data",
 ]

@@ -44,8 +46,8 @@ def get_data(filename):
     -------
     BytesIO of the data
     """
-    data_pkg = 'pymc3.examples'
-    return io.BytesIO(pkgutil.get_data(data_pkg, os.path.join('data', filename)))
+    data_pkg = "pymc3.examples"
+    return io.BytesIO(pkgutil.get_data(data_pkg, os.path.join("data", filename)))


 class GenTensorVariable(tt.TensorVariable):
@@ -78,14 +80,14 @@ def make_variable(self, gop, name=None):

     def __init__(self, generator):
         if not pm.vartypes.isgenerator(generator):
-            raise TypeError('Object should be generator like')
+            raise TypeError("Object should be generator like")
         self.test_value = pm.smartfloatX(copy(next(generator)))
         # make pickling potentially possible
         self._yielded_test_value = False
         self.gen = generator
         self.tensortype = tt.TensorType(
-            self.test_value.dtype,
-            ((False, ) * self.test_value.ndim))
+            self.test_value.dtype, ((False,) * self.test_value.ndim)
+        )

     # python3 generator
     def __next__(self):
@@ -283,28 +285,37 @@ class Minibatch(tt.TensorVariable):
     >>> assert x.eval().shape == (2, 20, 20, 40, 10)
     """

-    RNG = collections.defaultdict(list) # type: Dict[str, List[Any]]
-
-    @theano.configparser.change_flags(compute_test_value='raise')
-    def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',
-                 random_seed=42, update_shared_f=None, in_memory_size=None):
+    RNG = collections.defaultdict(list)  # type: Dict[str, List[Any]]
+
+    @theano.configparser.change_flags(compute_test_value="raise")
+    def __init__(
+        self,
+        data,
+        batch_size=128,
+        dtype=None,
+        broadcastable=None,
+        name="Minibatch",
+        random_seed=42,
+        update_shared_f=None,
+        in_memory_size=None,
+    ):
         if dtype is None:
             data = pm.smartfloatX(np.asarray(data))
         else:
             data = np.asarray(data, dtype)
         in_memory_slc = self.make_static_slices(in_memory_size)
         self.shared = theano.shared(data[in_memory_slc])
         self.update_shared_f = update_shared_f
-        self.random_slc = self.make_random_slices(self.shared.shape, batch_size, random_seed)
+        self.random_slc = self.make_random_slices(
+            self.shared.shape, batch_size, random_seed
+        )
         minibatch = self.shared[self.random_slc]
         if broadcastable is None:
-            broadcastable = (False, ) * minibatch.ndim
+            broadcastable = (False,) * minibatch.ndim
         minibatch = tt.patternbroadcast(minibatch, broadcastable)
         self.minibatch = minibatch
         super().__init__(self.minibatch.type, None, None, name=name)
-        theano.Apply(
-            theano.compile.view_op,
-            inputs=[self.minibatch], outputs=[self])
+        theano.Apply(theano.compile.view_op, inputs=[self.minibatch], outputs=[self])
         self.tag.test_value = copy(self.minibatch.tag.test_value)

     def rslice(self, total, size, seed):
@@ -313,11 +324,11 @@ def rslice(self, total, size, seed):
         elif isinstance(size, int):
             rng = pm.tt_rng(seed)
             Minibatch.RNG[id(self)].append(rng)
-            return (rng
-                    .uniform(size=(size, ), low=0.0, high=pm.floatX(total) - 1e-16)
-                    .astype('int64'))
+            return rng.uniform(
+                size=(size,), low=0.0, high=pm.floatX(total) - 1e-16
+            ).astype("int64")
         else:
-            raise TypeError('Unrecognized size type, %r' % size)
+            raise TypeError("Unrecognized size type, %r" % size)

     def __del__(self):
         del Minibatch.RNG[id(self)]
@@ -340,17 +351,18 @@ def make_static_slices(user_size):
                 elif isinstance(i, slice):
                     slc.append(i)
                 else:
-                    raise TypeError('Unrecognized size type, %r' % user_size)
+                    raise TypeError("Unrecognized size type, %r" % user_size)
             return slc
         else:
-            raise TypeError('Unrecognized size type, %r' % user_size)
+            raise TypeError("Unrecognized size type, %r" % user_size)

     def make_random_slices(self, in_memory_shape, batch_size, default_random_seed):
         if batch_size is None:
             return [Ellipsis]
         elif isinstance(batch_size, int):
             slc = [self.rslice(in_memory_shape[0], batch_size, default_random_seed)]
         elif isinstance(batch_size, (list, tuple)):
+
             def check(t):
                 if t is Ellipsis or t is None:
                     return True
@@ -364,12 +376,14 @@ def check(t):
                     return True
                 else:
                     return False
+
             # end check definition
             if not all(check(t) for t in batch_size):
-                raise TypeError('Unrecognized `batch_size` type, expected '
-                                'int or List[int|tuple(size, random_seed)] where '
-                                'size and random seed are both ints, got %r' %
-                                batch_size)
+                raise TypeError(
+                    "Unrecognized `batch_size` type, expected "
+                    "int or List[int|tuple(size, random_seed)] where "
+                    "size and random seed are both ints, got %r" % batch_size
+                )
             batch_size = [
                 (i, default_random_seed) if isinstance(i, int) else i
                 for i in batch_size
@@ -378,12 +392,14 @@ def check(t):
             if Ellipsis in batch_size:
                 sep = batch_size.index(Ellipsis)
                 begin = batch_size[:sep]
-                end = batch_size[sep + 1:]
+                end = batch_size[sep + 1 :]
                 if Ellipsis in end:
-                    raise ValueError('Double Ellipsis in `batch_size` is restricted, got %r' %
-                                     batch_size)
+                    raise ValueError(
+                        "Double Ellipsis in `batch_size` is restricted, got %r"
+                        % batch_size
+                    )
                 if len(end) > 0:
-                    shp_mid = shape[sep:-len(end)]
+                    shp_mid = shape[sep : -len(end)]
                     mid = [tt.arange(s) for s in shp_mid]
                 else:
                     mid = []
@@ -392,23 +408,30 @@ def check(t):
                 end = []
                 mid = []
             if (len(begin) + len(end)) > len(in_memory_shape.eval()):
-                raise ValueError('Length of `batch_size` is too big, '
-                                 'number of ints is bigger that ndim, got %r'
-                                 % batch_size)
+                raise ValueError(
+                    "Length of `batch_size` is too big, "
+                    "number of ints is bigger that ndim, got %r" % batch_size
+                )
             if len(end) > 0:
-                shp_end = shape[-len(end):]
+                shp_end = shape[-len(end) :]
             else:
                 shp_end = np.asarray([])
-            shp_begin = shape[:len(begin)]
-            slc_begin = [self.rslice(shp_begin[i], t[0], t[1])
-                         if t is not None else tt.arange(shp_begin[i])
-                         for i, t in enumerate(begin)]
-            slc_end = [self.rslice(shp_end[i], t[0], t[1])
-                       if t is not None else tt.arange(shp_end[i])
-                       for i, t in enumerate(end)]
+            shp_begin = shape[: len(begin)]
+            slc_begin = [
+                self.rslice(shp_begin[i], t[0], t[1])
+                if t is not None
+                else tt.arange(shp_begin[i])
+                for i, t in enumerate(begin)
+            ]
+            slc_end = [
+                self.rslice(shp_end[i], t[0], t[1])
+                if t is not None
+                else tt.arange(shp_end[i])
+                for i, t in enumerate(end)
+            ]
             slc = slc_begin + mid + slc_end
         else:
-            raise TypeError('Unrecognized size type, %r' % batch_size)
+            raise TypeError("Unrecognized size type, %r" % batch_size)
         return pm.theanof.ix_(*slc)

     def update_shared(self):
@@ -434,7 +457,7 @@ def align_minibatches(batches=None):
     else:
         for b in batches:
             if not isinstance(b, Minibatch):
-                raise TypeError('{b} is not a Minibatch')
+                raise TypeError("{b} is not a Minibatch")
             for rng in Minibatch.RNG[id(b)]:
                 rng.seed()

@@ -447,8 +470,17 @@ class Data:
     ----------
     name: str
         The name for this variable
-    value
+    value: {List, np.ndarray, pd.Series, pd.Dataframe}
         A value to associate with this variable
+    dims: {str, tuple of str}, optional, default=None
+        Dimension names of the random variables (as opposed to the shapes of these
+        random variables). Use this when `value` is a Pandas Series or DataFrame. The
+        `dims` will then be the name of the Series / DataFrame's columns. See ArviZ
+        documentation for more information about dimensions and coordinates:
+        https://arviz-devs.github.io/arviz/notebooks/Introduction.html
+    export_index_as_coords: bool, optional, default=False
+        If True, the `Data` container will try to infer what the coordinates should be
+        if there is an index in `value`.

     Examples
     --------
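
The two new parameters documented in the hunk above can be exercised roughly as follows. This is a hedged sketch, not code from the commit; the DataFrame contents and the dimension names "date" and "city" are made up:

import pandas as pd
import pymc3 as pm

# A DataFrame whose index and columns carry meaningful labels.
df = pd.DataFrame(
    {"Berlin": [3.1, 3.3], "Paris": [4.0, 4.2]},
    index=pd.Index(["2019-01", "2019-02"], name="date"),
)
df.columns.name = "city"

with pm.Model() as model:
    # dims names the two axes (rows -> "date", columns -> "city");
    # export_index_as_coords=True registers the index and the columns
    # as model coordinates, which ArviZ can then use as labels.
    data = pm.Data("data", df, dims=("date", "city"), export_index_as_coords=True)
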
@@ -479,7 +511,7 @@ class Data:
     https://docs.pymc.io/notebooks/data_container.html
     """

-    def __new__(self, name, value):
+    def __new__(self, name, value, *, dims=None, export_index_as_coords=False):
         if isinstance(value, list):
             value = np.array(value)

@@ -497,10 +529,68 @@ def __new__(self, name, value):
         # transforms it to something digestible for pymc3
         shared_object = theano.shared(pm.model.pandas_to_array(value), name)

+        if isinstance(dims, str):
+            dims = (dims,)
+        if not (dims is None or len(dims) == shared_object.ndim):
+            raise pm.exceptions.ShapeError(
+                "Length of `dims` must match the dimensions of the dataset.",
+                actual=len(dims), expected=shared_object.ndim
+            )
+
+        coords = self.set_coords(model, value, dims)
+
+        if export_index_as_coords:
+            model.add_coords(coords)
+
         # To draw the node for this variable in the graphviz Digraph we need
         # its shape.
         shared_object.dshape = tuple(shared_object.shape.eval())
+        if dims is not None:
+            shape_dims = model.shape_from_dims(dims)
+            if shared_object.dshape != shape_dims:
+                raise pm.exceptions.ShapeError(
+                    "Data shape does not match with specified `dims`.",
+                    actual=shared_object.dshape, expected=shape_dims
+                )

-        model.add_random_variable(shared_object)
+        model.add_random_variable(shared_object, dims=dims)

         return shared_object
+
+    @staticmethod
+    def set_coords(model, value, dims=None):
+        coords = {}
+
+        # If value is a df or a series, we interpret the index as coords:
+        if isinstance(value, (pd.Series, pd.DataFrame)):
+            dim_name = None
+            if dims is not None:
+                dim_name = dims[0]
+            if dim_name is None and value.index.name is not None:
+                dim_name = value.index.name
+            if dim_name is not None:
+                coords[dim_name] = value.index
+
+        # If value is a df, we also interpret the columns as coords:
+        if isinstance(value, pd.DataFrame):
+            dim_name = None
+            if dims is not None:
+                dim_name = dims[1]
+            if dim_name is None and value.columns.name is not None:
+                dim_name = value.columns.name
+            if dim_name is not None:
+                coords[dim_name] = value.columns
+
+        if isinstance(value, np.ndarray) and dims is not None:
+            if len(dims) != value.ndim:
+                raise pm.exceptions.ShapeError(
+                    "Invalid data shape. The rank of the dataset must match the "
+                    "length of `dims`.",
+                    actual=value.shape, expected=value.ndim
+                )
+            for size, dim in zip(value.shape, dims):
+                coord = model.coords.get(dim, None)
+                if coord is None:
+                    coords[dim] = pd.RangeIndex(size, name=dim)
+
+        return coords
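
For reference, a hedged sketch of what the coordinate inference implemented in set_coords above looks like from the user's side (the Series and its index name "label" are invented for the example):

import pandas as pd
import pymc3 as pm

# A Series with a named index: the index values become the coordinates
# and the dimension is referred to by the name "label".
series = pd.Series([1.2, 3.4, 5.6], index=pd.Index(["a", "b", "c"], name="label"))

with pm.Model() as model:
    pm.Data("obs", series, dims="label", export_index_as_coords=True)

# The inferred coordinate is stored on the model, keyed by the dim name;
# model.coords["label"] should hold the index ['a', 'b', 'c'].
print(model.coords["label"])
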

0 commit comments