Skip to content

Commit db789ec

Browse files
author
Larry Dong
committed
Merge remote-tracking branch 'upstream/v4'
2 parents 903eaa2 + 7adf05d commit db789ec

19 files changed

+360
-269
lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
--ignore=pymc3/tests/test_shape_handling.py
6363
--ignore=pymc3/tests/test_distributions.py
6464
--ignore=pymc3/tests/test_distributions_random.py
65+
--ignore=pymc3/tests/test_idata_conversion.py
6566
6667
- |
6768
pymc3/tests/test_modelcontext.py
@@ -73,6 +74,7 @@ jobs:
7374
pymc3/tests/test_updates.py
7475
7576
- |
77+
pymc3/tests/test_idata_conversion.py
7678
pymc3/tests/test_distributions.py
7779
pymc3/tests/test_distributions_random.py
7880
pymc3/tests/test_examples.py

pymc3/aesaraf.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from typing import (
1517
Callable,
1618
Dict,
@@ -169,19 +171,17 @@ def change_rv_size(
169171
def extract_rv_and_value_vars(
170172
var: TensorVariable,
171173
) -> Tuple[TensorVariable, TensorVariable]:
172-
"""Extract a random variable and its corresponding value variable from a generic
173-
`TensorVariable`.
174+
"""Return a random variable and it's observations or value variable, or ``None``.
174175
175176
Parameters
176177
==========
177178
var
178-
A variable corresponding to a `RandomVariable`.
179+
A variable corresponding to a ``RandomVariable``.
179180
180181
Returns
181182
=======
182-
The first value in the tuple is the `RandomVariable`, and the second is the
183-
measure-space variable that corresponds with the latter (i.e. the "value"
184-
variable).
183+
The first value in the tuple is the ``RandomVariable``, and the second is the
184+
measure/log-likelihood value variable that corresponds with the latter.
185185
186186
"""
187187
if not var.owner:
@@ -195,7 +195,7 @@ def extract_rv_and_value_vars(
195195

196196

197197
def extract_obs_data(x: TensorVariable) -> np.ndarray:
198-
"""Extract data observed symbolic variables.
198+
"""Extract data from observed symbolic variables.
199199
200200
Raises
201201
------
@@ -331,17 +331,24 @@ def transform_replacements(var, replacements):
331331
rv_var, rv_value_var = extract_rv_and_value_vars(var)
332332

333333
if rv_value_var is None:
334+
warnings.warn(
335+
f"No value variable found for {rv_var}; "
336+
"the random variable will not be replaced."
337+
)
334338
return []
335339

336340
transform = getattr(rv_value_var.tag, "transform", None)
337341

338342
if transform is None or not apply_transforms:
339343
replacements[var] = rv_value_var
340-
return []
344+
# In case the value variable is itself a graph, we walk it for
345+
# potential replacements
346+
return [rv_value_var]
341347

342348
trans_rv_value = transform.backward(rv_var, rv_value_var)
343349
replacements[var] = trans_rv_value
344350

351+
# Walk the transformed variable and make replacements
345352
return [trans_rv_value]
346353

347354
return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)

pymc3/blocking.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919
"""
2020
import collections
2121

22-
from typing import Dict, List, Optional, Union
22+
from functools import partial
23+
from typing import Callable, Dict, Optional, TypeVar
2324

2425
import numpy as np
2526

2627
__all__ = ["DictToArrayBijection"]
2728

29+
30+
T = TypeVar("T")
31+
PointType = Dict[str, np.ndarray]
32+
2833
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
2934
# each of the raveled variables.
3035
RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
@@ -38,7 +43,7 @@ class DictToArrayBijection:
3843
"""
3944

4045
@staticmethod
41-
def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars:
46+
def map(var_dict: PointType) -> RaveledVars:
4247
"""Map a dictionary of names and variables to a concatenated 1D array space."""
4348
vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
4449
raveled_vars = [v[0].ravel() for v in vars_info]
@@ -50,42 +55,41 @@ def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars:
5055

5156
@staticmethod
5257
def rmap(
53-
array: RaveledVars, as_list: Optional[bool] = False
54-
) -> Union[Dict[str, np.ndarray], List[np.ndarray]]:
58+
array: RaveledVars,
59+
start_point: Optional[PointType] = None,
60+
) -> PointType:
5561
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
5662
5763
Parameters
5864
==========
5965
array
6066
The array to map.
61-
as_list
62-
When ``True``, return a list of the original variables instead of a
63-
``dict`` keyed each variable's name.
67+
start_point
68+
An optional dictionary of initial values.
69+
6470
"""
65-
if as_list:
66-
res = []
71+
if start_point:
72+
res = dict(start_point)
6773
else:
6874
res = {}
6975

7076
if not isinstance(array, RaveledVars):
71-
raise TypeError("`apt` must be a `RaveledVars` type")
77+
raise TypeError("`array` must be a `RaveledVars` type")
7278

7379
last_idx = 0
7480
for name, shape, dtype in array.point_map_info:
7581
arr_len = np.prod(shape, dtype=int)
7682
var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
77-
if as_list:
78-
res.append(var)
79-
else:
80-
res[name] = var
83+
res[name] = var
8184
last_idx += arr_len
8285

8386
return res
8487

8588
@classmethod
86-
def mapf(cls, f):
87-
"""
88-
function f: DictSpace -> T to ArraySpace -> T
89+
def mapf(cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None) -> T:
90+
"""Create a callable that first maps back to ``dict`` inputs and then applies a function.
91+
92+
function f: DictSpace -> T to ArraySpace -> T
8993
9094
Parameters
9195
----------
@@ -95,7 +99,7 @@ def mapf(cls, f):
9599
-------
96100
f: array -> T
97101
"""
98-
return Compose(f, cls.rmap)
102+
return Compose(f, partial(cls.rmap, start_point=start_point))
99103

100104

101105
class Compose:

pymc3/distributions/discrete.py

Lines changed: 32 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
import aesara.tensor as at
1717
import numpy as np
1818

19-
from aesara.tensor.random.basic import bernoulli, binomial, categorical, nbinom, poisson
19+
from aesara.tensor.random.basic import (
20+
RandomVariable,
21+
bernoulli,
22+
binomial,
23+
categorical,
24+
nbinom,
25+
poisson,
26+
)
2027
from scipy import stats
2128

2229
from pymc3.aesaraf import floatX, intX, take_along_axis
@@ -434,6 +441,22 @@ def _distr_parameters_for_repr(self):
434441
return ["p"]
435442

436443

444+
class DiscreteWeibullRV(RandomVariable):
445+
name = "discrete_weibull"
446+
ndim_supp = 0
447+
ndims_params = [0, 0]
448+
dtype = "int64"
449+
_print_name = ("dWeibull", "\\operatorname{dWeibull}")
450+
451+
@classmethod
452+
def rng_fn(cls, rng, q, beta, size):
453+
p = rng.uniform(size=size)
454+
return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
455+
456+
457+
discrete_weibull = DiscreteWeibullRV()
458+
459+
437460
class DiscreteWeibull(Discrete):
438461
R"""Discrete Weibull log-likelihood
439462
@@ -473,51 +496,15 @@ def DiscreteWeibull(q, b, x):
473496
Variance :math:`2 \sum_{x = 1}^{\infty} x q^{x^{\beta}} - \mu - \mu^2`
474497
======== ======================
475498
"""
499+
rv_op = discrete_weibull
476500

477-
def __init__(self, q, beta, *args, **kwargs):
478-
super().__init__(*args, defaults=("median",), **kwargs)
479-
480-
self.q = at.as_tensor_variable(floatX(q))
481-
self.beta = at.as_tensor_variable(floatX(beta))
482-
483-
self.median = self._ppf(0.5)
484-
485-
def _ppf(self, p):
486-
r"""
487-
The percentile point function (the inverse of the cumulative
488-
distribution function) of the discrete Weibull distribution.
489-
"""
490-
q = self.q
491-
beta = self.beta
492-
493-
return (at.ceil(at.power(at.log(1 - p) / at.log(q), 1.0 / beta)) - 1).astype("int64")
494-
495-
def _random(self, q, beta, size=None):
496-
p = np.random.uniform(size=size)
497-
498-
return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
499-
500-
def random(self, point=None, size=None):
501-
r"""
502-
Draw random values from DiscreteWeibull distribution.
503-
504-
Parameters
505-
----------
506-
point: dict, optional
507-
Dict of variable values on which random values are to be
508-
conditioned (uses default point if not specified).
509-
size: int, optional
510-
Desired size of random sample (returns one sample if not
511-
specified).
512-
513-
Returns
514-
-------
515-
array
516-
"""
517-
# q, beta = draw_values([self.q, self.beta], point=point, size=size)
518-
# return generate_samples(self._random, q, beta, dist_shape=self.shape, size=size)
501+
@classmethod
502+
def dist(cls, q, beta, *args, **kwargs):
503+
q = at.as_tensor_variable(floatX(q))
504+
beta = at.as_tensor_variable(floatX(beta))
505+
return super().dist([q, beta], **kwargs)
519506

520-
def logp(self, value):
507+
def logp(value, q, beta):
521508
r"""
522509
Calculate log-probability of DiscreteWeibull distribution at specified value.
523510
@@ -531,8 +518,6 @@ def logp(self, value):
531518
-------
532519
TensorVariable
533520
"""
534-
q = self.q
535-
beta = self.beta
536521
return bound(
537522
at.log(at.power(q, at.power(value, beta)) - at.power(q, at.power(value + 1, beta))),
538523
0 <= value,
@@ -541,7 +526,7 @@ def logp(self, value):
541526
0 < beta,
542527
)
543528

544-
def logcdf(self, value):
529+
def logcdf(value, q, beta):
545530
"""
546531
Compute the log of the cumulative distribution function for Discrete Weibull distribution
547532
at the specified value.
@@ -556,9 +541,6 @@ def logcdf(self, value):
556541
-------
557542
TensorVariable
558543
"""
559-
q = self.q
560-
beta = self.beta
561-
562544
return bound(
563545
at.log1p(-at.power(q, at.power(value + 1, beta))),
564546
0 <= value,

0 commit comments

Comments
 (0)