25 | 25 | import scipy
26 | 26 |
27 | 27 | from aesara.assert_op import Assert
28 | | -from aesara.graph.basic import Apply
| 28 | +from aesara.graph.basic import Apply, Constant
29 | 29 | from aesara.graph.op import Op
30 | 30 | from aesara.sparse.basic import sp_sum
31 | 31 | from aesara.tensor import gammaln, sigmoid
43 | 43 |
44 | 44 | from pymc.aesaraf import floatX, intX
45 | 45 | from pymc.distributions import transforms
46 | | -from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
| 46 | +from pymc.distributions.continuous import (
| 47 | +    BoundedContinuous,
| 48 | +    ChiSquared,
| 49 | +    Normal,
| 50 | +    assert_negative_support,
| 51 | +)
47 | 52 | from pymc.distributions.dist_math import (
48 | 53 |     betaln,
49 | 54 |     check_parameters,
57 | 62 |     rv_size_is_none,
58 | 63 |     to_tuple,
59 | 64 | )
| 65 | +from pymc.distributions.transforms import interval
60 | 66 | from pymc.math import kron_diag, kron_dot
| 67 | +from pymc.util import UNSET
61 | 68 |
62 | 69 | __all__ = [
63 | 70 |     "MvNormal",
@@ -1079,6 +1086,11 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
1079 | 1086 |
1080 | 1087 |
1081 | 1088 | def _lkj_normalizing_constant(eta, n):
| 1089 | +    # TODO: This is mixing python branching with the potentially symbolic n and eta variables
| 1090 | +    if not isinstance(eta, (int, float)):
| 1091 | +        raise NotImplementedError("eta must be an int or float")
| 1092 | +    if not isinstance(n, int):
| 1093 | +        raise NotImplementedError("n must be an integer")
1082 | 1094 |     if eta == 1:
1083 | 1095 |         result = gammaln(2.0 * at.arange(1, int((n - 1) / 2) + 1)).sum()
1084 | 1096 |         if n % 2 == 1:
@@ -1431,7 +1443,74 @@ def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=Tru
1431 | 1443 |     return chol, corr, stds
1432 | 1444 |
1433 | 1445 |
1434 | | -class LKJCorr(Continuous):
| 1446 | +class LKJCorrRV(RandomVariable):
| 1447 | +    name = "lkjcorr"
| 1448 | +    ndim_supp = 1
| 1449 | +    ndims_params = [0, 0]
| 1450 | +    dtype = "floatX"
| 1451 | +    _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
| 1452 | +
| 1453 | +    def make_node(self, rng, size, dtype, n, eta):
| 1454 | +        n = at.as_tensor_variable(n)
| 1455 | +        if not n.ndim == 0:
| 1456 | +            raise ValueError("n must be a scalar (ndim=0).")
| 1457 | +
| 1458 | +        eta = at.as_tensor_variable(eta)
| 1459 | +        if not eta.ndim == 0:
| 1460 | +            raise ValueError("eta must be a scalar (ndim=0).")
| 1461 | +
| 1462 | +        return super().make_node(rng, size, dtype, n, eta)
| 1463 | +
| 1464 | +    def _shape_from_params(self, dist_params, **kwargs):
| 1465 | +        n = dist_params[0]
| 1466 | +        dist_shape = ((n * (n - 1)) // 2,)
| 1467 | +        return dist_shape
| 1468 | +
| 1469 | +    @classmethod
| 1470 | +    def rng_fn(cls, rng, n, eta, size):
| 1471 | +
| 1472 | +        # We flatten the size to make operations easier, and then rebuild it
| 1473 | +        if size is None:
| 1474 | +            flat_size = 1
| 1475 | +        else:
| 1476 | +            flat_size = np.prod(size)
| 1477 | +
| 1478 | +        C = cls._random_corr_matrix(rng, n, eta, flat_size)
| 1479 | +
| 1480 | +        triu_idx = np.triu_indices(n, k=1)
| 1481 | +        samples = C[..., triu_idx[0], triu_idx[1]]
| 1482 | +
| 1483 | +        if size is None:
| 1484 | +            samples = samples[0]
| 1485 | +        else:
| 1486 | +            dist_shape = (n * (n - 1)) // 2
| 1487 | +            samples = np.reshape(samples, (*size, dist_shape))
| 1488 | +        return samples
| 1489 | +
| 1490 | +    @classmethod
| 1491 | +    def _random_corr_matrix(cls, rng, n, eta, flat_size):
| 1492 | +        # original implementation in R see:
| 1493 | +        # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
| 1494 | +        beta = eta - 1.0 + n / 2.0
| 1495 | +        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0
| 1496 | +        P = np.full((flat_size, n, n), np.eye(n))
| 1497 | +        P[..., 0, 1] = r12
| 1498 | +        P[..., 1, 1] = np.sqrt(1.0 - r12 ** 2)
| 1499 | +        for mp1 in range(2, n):
| 1500 | +            beta -= 0.5
| 1501 | +            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=flat_size, random_state=rng)
| 1502 | +            z = stats.norm.rvs(loc=0, scale=1, size=(flat_size, mp1), random_state=rng)
| 1503 | +            z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
| 1504 | +            P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z
| 1505 | +            P[..., mp1, mp1] = np.sqrt(1.0 - y)
| 1506 | +        C = np.einsum("...ji,...jk->...ik", P, P)
| 1507 | +        return C
| 1508 | +
| 1509 | +
| 1510 | +lkjcorr = LKJCorrRV()
| 1511 | +
| 1512 | +
| 1513 | +class LKJCorr(BoundedContinuous):
1435 | 1514 |     r"""
1436 | 1515 |     The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
1437 | 1516 |
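The `lkjcorr` op instantiated at the end of the hunk above can be exercised directly, outside of any model. A minimal sketch, assuming this branch is installed and that the op is importable from `pymc.distributions.multivariate` (the file being patched); the parameter values and size are illustrative:

```python
import numpy as np

from pymc.distributions.multivariate import lkjcorr

# Calling the op with (n, eta) builds a symbolic draw; .eval() samples it.
# For n=3, each draw holds n*(n-1)//2 == 3 flattened upper-triangle entries.
draws = lkjcorr(3, 2.0, size=(5,)).eval()
print(draws.shape)                   # (5, 3)
assert np.all(np.abs(draws) <= 1.0)  # correlations lie in [-1, 1]
```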
@@ -1473,112 +1552,60 @@ class LKJCorr(Continuous):
1473 | 1552 |         100(9), pp.1989-2001.
1474 | 1553 |     """
1475 | 1554 |
1476 | | -    def __init__(self, eta=None, n=None, p=None, transform="interval", *args, **kwargs):
1477 | | -        if (p is not None) and (n is not None) and (eta is None):
1478 | | -            warnings.warn(
1479 | | -                "Parameters to LKJCorr have changed: shape parameter n -> eta "
1480 | | -                "dimension parameter p -> n. Please update your code. "
1481 | | -                "Automatically re-assigning parameters for backwards compatibility.",
1482 | | -                FutureWarning,
1483 | | -            )
1484 | | -            self.n = p
1485 | | -            self.eta = n
1486 | | -            eta = self.eta
1487 | | -            n = self.n
1488 | | -        elif (n is not None) and (eta is not None) and (p is None):
1489 | | -            self.n = n
1490 | | -            self.eta = eta
1491 | | -        else:
1492 | | -            raise ValueError(
1493 | | -                "Invalid parameter: please use eta as the shape parameter and "
1494 | | -                "n as the dimension parameter."
1495 | | -            )
1496 | | -
1497 | | -        shape = n * (n - 1) // 2
1498 | | -        self.mean = floatX(np.zeros(shape))
1499 | | -
1500 | | -        if transform == "interval":
1501 | | -            transform = transforms.interval(-1, 1)
1502 | | -
1503 | | -        super().__init__(shape=shape, transform=transform, *args, **kwargs)
1504 | | -        warnings.warn(
1505 | | -            "Parameters in LKJCorr have been rename: shape parameter n -> eta "
1506 | | -            "dimension parameter p -> n. Please double check your initialization.",
1507 | | -            FutureWarning,
1508 | | -        )
1509 | | -        self.tri_index = np.zeros([n, n], dtype="int32")
1510 | | -        self.tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
1511 | | -        self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1512 | | -
1513 | | -    def _random(self, n, eta, size=None):
1514 | | -        size = size if isinstance(size, tuple) else (size,)
1515 | | -        # original implementation in R see:
1516 | | -        # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
1517 | | -        beta = eta - 1.0 + n / 2.0
1518 | | -        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0
1519 | | -        P = np.eye(n)[:, :, np.newaxis] * np.ones(size)
1520 | | -        P[0, 1] = r12
1521 | | -        P[1, 1] = np.sqrt(1.0 - r12 ** 2)
1522 | | -        for mp1 in range(2, n):
1523 | | -            beta -= 0.5
1524 | | -            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size)
1525 | | -            z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size)
1526 | | -            z = z / np.sqrt(np.einsum("ij,ij->j", z, z))
1527 | | -            P[0:mp1, mp1] = np.sqrt(y) * z
1528 | | -            P[mp1, mp1] = np.sqrt(1.0 - y)
1529 | | -        C = np.einsum("ji...,jk...->...ik", P, P)
1530 | | -        triu_idx = np.triu_indices(n, k=1)
1531 | | -        return C[..., triu_idx[0], triu_idx[1]]
| 1555 | +    rv_op = lkjcorr
1532 | 1556 |
1533 | | -    def random(self, point=None, size=None):
1534 | | -        """
1535 | | -        Draw random values from LKJ distribution.
| 1557 | +    def __new__(cls, *args, **kwargs):
| 1558 | +        transform = kwargs.get("transform", UNSET)
| 1559 | +        if transform is UNSET:
| 1560 | +            kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0)))
| 1561 | +        return super().__new__(cls, *args, **kwargs)
1536 | 1562 |
1537 | | -        Parameters
1538 | | -        ----------
1539 | | -        point: dict, optional
1540 | | -            Dict of variable values on which random values are to be
1541 | | -            conditioned (uses default point if not specified).
1542 | | -        size: int, optional
1543 | | -            Desired size of random sample (returns one sample if not
1544 | | -            specified).
1545 | | -
1546 | | -        Returns
1547 | | -        -------
1548 | | -        array
1549 | | -        """
1550 | | -        # n, eta = draw_values([self.n, self.eta], point=point, size=size)
1551 | | -        # size = 1 if size is None else size
1552 | | -        # samples = generate_samples(self._random, n, eta, broadcast_shape=(size,))
1553 | | -        # return samples
| 1563 | +    @classmethod
| 1564 | +    def dist(cls, n, eta, **kwargs):
| 1565 | +        n = at.as_tensor_variable(intX(n))
| 1566 | +        eta = at.as_tensor_variable(floatX(eta))
| 1567 | +        return super().dist([n, eta], **kwargs)
1554 | 1568 |
1555 | | -    def logp(self, x):
| 1569 | +    def logp(value, n, eta):
1556 | 1570 |         """
1557 | 1571 |         Calculate log-probability of LKJ distribution at specified
1558 | 1572 |         value.
1559 | 1573 |
1560 | 1574 |         Parameters
1561 | 1575 |         ----------
1562 | | -        x: numeric
| 1576 | +        value: numeric
1563 | 1577 |             Value for which log-probability is calculated.
1564 | 1578 |
1565 | 1579 |         Returns
1566 | 1580 |         -------
1567 | 1581 |         TensorVariable
1568 | 1582 |         """
1569 | | -        n = self.n
1570 | | -        eta = self.eta
1571 | 1583 |
1572 | | -        X = x[self.tri_index]
1573 | | -        X = at.fill_diagonal(X, 1)
| 1584 | +        # TODO: Aesara does not have a `triu_indices`, so we can only work with constant
| 1585 | +        #  n (or else find a different expression)
| 1586 | +        if not isinstance(n, Constant):
| 1587 | +            raise NotImplementedError("logp only implemented for constant `n`")
| 1588 | +
| 1589 | +        n = int(n.data)
| 1590 | +        shape = n * (n - 1) // 2
| 1591 | +        tri_index = np.zeros((n, n), dtype="int32")
| 1592 | +        tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
| 1593 | +        tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1574 | 1594 |
| 1595 | +        value = at.take(value, tri_index)
| 1596 | +        value = at.fill_diagonal(value, 1)
| 1597 | +
| 1598 | +        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
| 1599 | +        if not isinstance(eta, Constant):
| 1600 | +            raise NotImplementedError("logp only implemented for constant `eta`")
| 1601 | +        eta = float(eta.data)
1575 | 1602 |         result = _lkj_normalizing_constant(eta, n)
1576 | | -        result += (eta - 1.0) * at.log(det(X))
| 1603 | +        result += (eta - 1.0) * at.log(det(value))
1577 | 1604 |         return check_parameters(
1578 | 1605 |             result,
1579 | | -            X >= -1,
1580 | | -            X <= 1,
1581 | | -            matrix_pos_def(X),
| 1606 | +            value >= -1,
| 1607 | +            value <= 1,
| 1608 | +            matrix_pos_def(value),
1582 | 1609 |             eta > 0,
1583 | 1610 |         )
1584 | 1611 |
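Because `logp` in the hunk above leans on the `tri_index` trick to rebuild a full matrix from the flattened value vector, a plain-NumPy sketch of that round trip may help; the dimension and the stand-in draw below are illustrative, not part of the patch:

```python
import numpy as np

n = 4
shape = n * (n - 1) // 2  # number of distinct off-diagonal entries

# tri_index maps each (i, j) position to the index of its flattened entry,
# symmetrically for the upper and lower triangles, exactly as in logp.
tri_index = np.zeros((n, n), dtype="int32")
tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

flat = np.random.default_rng(0).uniform(-1, 1, size=shape)  # stand-in draw
C = flat[tri_index]       # symmetric (n, n) matrix of off-diagonal entries
np.fill_diagonal(C, 1.0)  # unit diagonal, mirroring at.fill_diagonal(value, 1)

assert np.allclose(C, C.T)
```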
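At the model level, the refactored class keeps the familiar constructor, with the (-1, 1) interval transform attached automatically by `__new__`. A minimal usage sketch; the variable names and sampler settings are made up for illustration:

```python
import pymc as pm

with pm.Model():
    # n is the matrix dimension; eta > 1 concentrates mass near the identity.
    corr = pm.LKJCorr("corr", n=3, eta=2.0)
    idata = pm.sample(chains=2, tune=500, draws=500)
```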