Commit 5fd94ca

Refactor LKJCorr distribution to V4

1 parent 2daa766 commit 5fd94ca

3 files changed: +156 -113 lines changed

pymc/distributions/multivariate.py

Lines changed: 116 additions & 89 deletions
@@ -25,7 +25,7 @@
 import scipy

 from aesara.assert_op import Assert
-from aesara.graph.basic import Apply
+from aesara.graph.basic import Apply, Constant
 from aesara.graph.op import Op
 from aesara.sparse.basic import sp_sum
 from aesara.tensor import gammaln, sigmoid
@@ -43,7 +43,12 @@

 from pymc.aesaraf import floatX, intX
 from pymc.distributions import transforms
-from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
+from pymc.distributions.continuous import (
+    BoundedContinuous,
+    ChiSquared,
+    Normal,
+    assert_negative_support,
+)
 from pymc.distributions.dist_math import (
     betaln,
     check_parameters,
@@ -57,7 +62,9 @@
     rv_size_is_none,
     to_tuple,
 )
+from pymc.distributions.transforms import interval
 from pymc.math import kron_diag, kron_dot
+from pymc.util import UNSET

 __all__ = [
     "MvNormal",
@@ -1079,6 +1086,11 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv


 def _lkj_normalizing_constant(eta, n):
+    # TODO: This is mixing python branching with the potentially symbolic n and eta variables
+    if not isinstance(eta, (int, float)):
+        raise NotImplementedError("eta must be an int or float")
+    if not isinstance(n, int):
+        raise NotImplementedError("n must be an integer")
     if eta == 1:
         result = gammaln(2.0 * at.arange(1, int((n - 1) / 2) + 1)).sum()
         if n % 2 == 1:
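
For reference, the density this constant normalizes is the one assembled in logp further down:

    log p(R | n, eta) = log c(n, eta) + (eta - 1) * log det(R)

where c(n, eta) is the value returned here. Because the constant is computed eagerly in Python, both eta and n must be concrete numbers at graph-construction time, which is exactly what the new guard clauses enforce.
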
@@ -1431,7 +1443,74 @@ def LKJCholeskyCov(name, eta, n, sd_dist, compute_corr=False, store_in_trace=Tru
     return chol, corr, stds


-class LKJCorr(Continuous):
+class LKJCorrRV(RandomVariable):
+    name = "lkjcorr"
+    ndim_supp = 1
+    ndims_params = [0, 0]
+    dtype = "floatX"
+    _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
+
+    def make_node(self, rng, size, dtype, n, eta):
+        n = at.as_tensor_variable(n)
+        if not n.ndim == 0:
+            raise ValueError("n must be a scalar (ndim=0).")
+
+        eta = at.as_tensor_variable(eta)
+        if not eta.ndim == 0:
+            raise ValueError("eta must be a scalar (ndim=0).")
+
+        return super().make_node(rng, size, dtype, n, eta)
+
+    def _shape_from_params(self, dist_params, **kwargs):
+        n = dist_params[0]
+        dist_shape = ((n * (n - 1)) // 2,)
+        return dist_shape
+
+    @classmethod
+    def rng_fn(cls, rng, n, eta, size):
+
+        # We flatten the size to make operations easier, and then rebuild it
+        if size is None:
+            flat_size = 1
+        else:
+            flat_size = np.prod(size)
+
+        C = cls._random_corr_matrix(rng, n, eta, flat_size)
+
+        triu_idx = np.triu_indices(n, k=1)
+        samples = C[..., triu_idx[0], triu_idx[1]]
+
+        if size is None:
+            samples = samples[0]
+        else:
+            dist_shape = (n * (n - 1)) // 2
+            samples = np.reshape(samples, (*size, dist_shape))
+        return samples
+
+    @classmethod
+    def _random_corr_matrix(cls, rng, n, eta, flat_size):
+        # original implementation in R see:
+        # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        beta = eta - 1.0 + n / 2.0
+        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=flat_size, random_state=rng) - 1.0
+        P = np.full((flat_size, n, n), np.eye(n))
+        P[..., 0, 1] = r12
+        P[..., 1, 1] = np.sqrt(1.0 - r12 ** 2)
+        for mp1 in range(2, n):
+            beta -= 0.5
+            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=flat_size, random_state=rng)
+            z = stats.norm.rvs(loc=0, scale=1, size=(flat_size, mp1), random_state=rng)
+            z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
+            P[..., 0:mp1, mp1] = np.sqrt(y[..., np.newaxis]) * z
+            P[..., mp1, mp1] = np.sqrt(1.0 - y)
+        C = np.einsum("...ji,...jk->...ik", P, P)
+        return C
+
+
+lkjcorr = LKJCorrRV()
+
+
+class LKJCorr(BoundedContinuous):
     r"""
     The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.

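The new LKJCorrRV op ports the onion-style sampler from the old _random method, now vectorized over a flattened size. A minimal sketch of exercising it directly (hypothetical usage, not part of the commit; it assumes the module path introduced above and that scipy accepts the NumPy Generator passed as rng):

import numpy as np
from pymc.distributions.multivariate import LKJCorrRV

rng = np.random.default_rng(42)
# n=3, eta=2.0, four draws; each draw holds the n*(n-1)//2 = 3
# upper-triangle entries of a sampled correlation matrix
draws = LKJCorrRV.rng_fn(rng, 3, 2.0, (4,))
print(draws.shape)  # (4, 3)
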
@@ -1473,112 +1552,60 @@ class LKJCorr(Continuous):
     100(9), pp.1989-2001.
     """

-    def __init__(self, eta=None, n=None, p=None, transform="interval", *args, **kwargs):
-        if (p is not None) and (n is not None) and (eta is None):
-            warnings.warn(
-                "Parameters to LKJCorr have changed: shape parameter n -> eta "
-                "dimension parameter p -> n. Please update your code. "
-                "Automatically re-assigning parameters for backwards compatibility.",
-                FutureWarning,
-            )
-            self.n = p
-            self.eta = n
-            eta = self.eta
-            n = self.n
-        elif (n is not None) and (eta is not None) and (p is None):
-            self.n = n
-            self.eta = eta
-        else:
-            raise ValueError(
-                "Invalid parameter: please use eta as the shape parameter and "
-                "n as the dimension parameter."
-            )
-
-        shape = n * (n - 1) // 2
-        self.mean = floatX(np.zeros(shape))
-
-        if transform == "interval":
-            transform = transforms.interval(-1, 1)
-
-        super().__init__(shape=shape, transform=transform, *args, **kwargs)
-        warnings.warn(
-            "Parameters in LKJCorr have been rename: shape parameter n -> eta "
-            "dimension parameter p -> n. Please double check your initialization.",
-            FutureWarning,
-        )
-        self.tri_index = np.zeros([n, n], dtype="int32")
-        self.tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
-        self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
-
-    def _random(self, n, eta, size=None):
-        size = size if isinstance(size, tuple) else (size,)
-        # original implementation in R see:
-        # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
-        beta = eta - 1.0 + n / 2.0
-        r12 = 2.0 * stats.beta.rvs(a=beta, b=beta, size=size) - 1.0
-        P = np.eye(n)[:, :, np.newaxis] * np.ones(size)
-        P[0, 1] = r12
-        P[1, 1] = np.sqrt(1.0 - r12 ** 2)
-        for mp1 in range(2, n):
-            beta -= 0.5
-            y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size)
-            z = stats.norm.rvs(loc=0, scale=1, size=(mp1,) + size)
-            z = z / np.sqrt(np.einsum("ij,ij->j", z, z))
-            P[0:mp1, mp1] = np.sqrt(y) * z
-            P[mp1, mp1] = np.sqrt(1.0 - y)
-        C = np.einsum("ji...,jk...->...ik", P, P)
-        triu_idx = np.triu_indices(n, k=1)
-        return C[..., triu_idx[0], triu_idx[1]]
+    rv_op = lkjcorr

-    def random(self, point=None, size=None):
-        """
-        Draw random values from LKJ distribution.
+    def __new__(cls, *args, **kwargs):
+        transform = kwargs.get("transform", UNSET)
+        if transform is UNSET:
+            kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0)))
+        return super().__new__(cls, *args, **kwargs)

-        Parameters
-        ----------
-        point: dict, optional
-            Dict of variable values on which random values are to be
-            conditioned (uses default point if not specified).
-        size: int, optional
-            Desired size of random sample (returns one sample if not
-            specified).
-
-        Returns
-        -------
-        array
-        """
-        # n, eta = draw_values([self.n, self.eta], point=point, size=size)
-        # size = 1 if size is None else size
-        # samples = generate_samples(self._random, n, eta, broadcast_shape=(size,))
-        # return samples
+    @classmethod
+    def dist(cls, n, eta, **kwargs):
+        n = at.as_tensor_variable(intX(n))
+        eta = at.as_tensor_variable(floatX(eta))
+        return super().dist([n, eta], **kwargs)

-    def logp(self, x):
+    def logp(value, n, eta):
         """
         Calculate log-probability of LKJ distribution at specified
         value.

         Parameters
         ----------
-        x: numeric
+        value: numeric
             Value for which log-probability is calculated.

         Returns
         -------
         TensorVariable
         """
-        n = self.n
-        eta = self.eta

-        X = x[self.tri_index]
-        X = at.fill_diagonal(X, 1)
+        # TODO: Aesara does not have a `triu_indices`, so we can only work with constant
+        # n (or else find a different expression)
+        if not isinstance(n, Constant):
+            raise NotImplementedError("logp only implemented for constant `n`")
+
+        n = int(n.data)
+        shape = n * (n - 1) // 2
+        tri_index = np.zeros((n, n), dtype="int32")
+        tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
+        tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

+        value = at.take(value, tri_index)
+        value = at.fill_diagonal(value, 1)
+
+        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
+        if not isinstance(eta, Constant):
+            raise NotImplementedError("logp only implemented for constant `eta`")
+        eta = float(eta.data)
         result = _lkj_normalizing_constant(eta, n)
-        result += (eta - 1.0) * at.log(det(X))
+        result += (eta - 1.0) * at.log(det(value))
         return check_parameters(
             result,
-            X >= -1,
-            X <= 1,
-            matrix_pos_def(X),
+            value >= -1,
+            value <= 1,
+            matrix_pos_def(value),
             eta > 0,
         )
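
With rv_op, dist, and logp wired up, LKJCorr now participates in the standard v4 machinery, including the default interval transform installed by __new__. A hedged end-to-end sketch (model and variable names are illustrative, not part of this commit):

import pymc as pm

with pm.Model() as model:
    # length-3 vector of off-diagonal correlations of a 3x3 matrix,
    # each mapped to (-1, 1) by the default interval transform
    corr = pm.LKJCorr("corr", n=3, eta=2.0)
    prior = pm.sample_prior_predictive(samples=200)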

pymc/tests/test_distributions.py

Lines changed: 1 addition & 2 deletions
@@ -2094,8 +2094,7 @@ def test_wishart(self, n):
        )

    @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
-    @pytest.mark.xfail(reason="Distribution not refactored yet")
-    def test_lkj(self, x, eta, n, lp):
+    def test_lkjcorr(self, x, eta, n, lp):
        with Model() as model:
            LKJCorr("lkj", eta=eta, n=n, transform=None)
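
A useful cross-check on the LKJ_CASES values: for n = 2 the LKJ distribution reduces to a scaled Beta, r = 2u - 1 with u ~ Beta(eta, eta), the same fact the refactored draw test below exploits in its ref_rand. A small sketch with an illustrative point, not an actual LKJ_CASES entry:

import numpy as np
from scipy import stats

eta, r = 3.0, 0.2
# change of variables u = (r + 1) / 2 contributes a Jacobian factor of 1/2
logp = stats.beta.logpdf((r + 1) / 2, eta, eta) - np.log(2)
print(logp)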

pymc/tests/test_distributions_random.py

Lines changed: 39 additions & 22 deletions
@@ -1828,29 +1828,46 @@ def kronecker_rng_fn(self, size, mu, covs=None, sigma=None, rng=None):
 ]


-class TestScalarParameterSamples(SeededTest):
-    @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
-    def test_lkj(self):
-        for n in [2, 10, 50]:
-            # pylint: disable=cell-var-from-loop
-            shape = n * (n - 1) // 2
-
-            def ref_rand(size, eta):
-                beta = eta - 1 + n / 2
-                return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
-
-            class TestedLKJCorr(pm.LKJCorr):
-                def __init__(self, **kwargs):
-                    kwargs.pop("shape", None)
-                    super().__init__(n=n, **kwargs)
-
-            pymc_random(
-                TestedLKJCorr,
-                {"eta": Domain([1.0, 10.0, 100.0])},
-                size=10000 // n,
-                ref_rand=ref_rand,
-            )
+class TestLKJCorr(BaseTestDistributionRandom):
+    pymc_dist = pm.LKJCorr
+    pymc_dist_params = {"n": 3, "eta": 1.0}
+    expected_rv_op_params = {"n": 3, "eta": 1.0}
+
+    sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
+    sizes_expected = [
+        (3,),
+        (3,),
+        (1, 3),
+        (1, 3),
+        (5, 3),
+        (4, 5, 3),
+        (2, 4, 2, 3),
+    ]
+
+    tests_to_run = [
+        "check_pymc_params_match_rv_op",
+        "check_rv_size",
+        "check_draws_match_expected",
+    ]

+    def check_draws_match_expected(self):
+        def ref_rand(size, n, eta):
+            shape = int(n * (n - 1) // 2)
+            beta = eta - 1 + n / 2
+            return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
+
+        pymc_random(
+            pm.LKJCorr,
+            {
+                "n": Domain([2, 10, 50], edges=(None, None)),
+                "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
+            },
+            ref_rand=ref_rand,
+            size=1000,
+        )
+
+
+class TestScalarParameterSamples(SeededTest):
     @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
     def test_normalmixture(self):
         def ref_rand(size, w, mu, sigma):
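
The sizes_expected mapping asserted above can be spot-checked by evaluating the RV directly; a small sketch (assuming the v4 dev API at this commit):

import pymc as pm

rv = pm.LKJCorr.dist(n=3, eta=1.0, size=(4, 5))
print(rv.eval().shape)  # expect (4, 5, 3): the size dims, then n*(n-1)//2 = 3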
