Skip to content

Commit 2a1071d

Browse files
committed
Refactor OrderedLogistic
1 parent ea7afba commit 2a1071d

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

pymc3/distributions/discrete.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,11 +1717,14 @@ class OrderedLogistic(Categorical):
17171717
17181718
"""
17191719

1720-
def __init__(self, eta, cutpoints, *args, **kwargs):
1721-
self.eta = at.as_tensor_variable(floatX(eta))
1722-
self.cutpoints = at.as_tensor_variable(cutpoints)
1720+
rv_op = categorical
17231721

1724-
pa = sigmoid(self.cutpoints - at.shape_padright(self.eta))
1722+
@classmethod
1723+
def dist(cls, eta, cutpoints, *args, **kwargs):
1724+
eta = at.as_tensor_variable(floatX(eta))
1725+
cutpoints = at.as_tensor_variable(cutpoints)
1726+
1727+
pa = sigmoid(cutpoints - at.shape_padright(eta))
17251728
p_cum = at.concatenate(
17261729
[
17271730
at.zeros_like(at.shape_padright(pa[..., 0])),
@@ -1732,7 +1735,7 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
17321735
)
17331736
p = p_cum[..., 1:] - p_cum[..., :-1]
17341737

1735-
super().__init__(p=p, *args, **kwargs)
1738+
return super().dist(p, **kwargs)
17361739

17371740

17381741
class OrderedProbit(Categorical):

pymc3/tests/test_distributions.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
continuous,
102102
logcdf,
103103
logpt,
104+
logpt_sum,
104105
)
105106
from pymc3.math import kronecker, logsumexp
106107
from pymc3.model import Deterministic, Model, Point
@@ -2297,7 +2298,6 @@ def test_categorical(self, n):
22972298
)
22982299

22992300
@pytest.mark.parametrize("n", [2, 3, 4])
2300-
@pytest.mark.xfail(reason="Distribution not refactored yet")
23012301
def test_orderedlogistic(self, n):
23022302
self.check_logp(
23032303
OrderedLogistic,
@@ -2738,27 +2738,34 @@ def test_discrete_trafo():
27382738
err.match("Transformations for discrete distributions")
27392739

27402740

2741+
# TODO: Is this test working as expected / still relevant?
27412742
@pytest.mark.parametrize("shape", [tuple(), (1,), (3, 1), (3, 2)], ids=str)
2742-
@pytest.mark.xfail(reason="Distribution not refactored yet")
27432743
def test_orderedlogistic_dimensions(shape):
27442744
# Test for issue #3535
27452745
loge = np.log10(np.exp(1))
27462746
size = 7
27472747
p = np.ones(shape + (10,)) / 10
27482748
cutpoints = np.tile(logit(np.linspace(0, 1, 11)[1:-1]), shape + (1,))
2749-
obs = np.random.randint(0, 1, size=(size,) + shape)
2749+
obs = np.random.randint(0, 2, size=(size,) + shape)
27502750
with Model():
27512751
ol = OrderedLogistic(
2752-
"ol", eta=np.zeros(shape), cutpoints=cutpoints, size=shape, observed=obs
2753-
)
2754-
c = Categorical("c", p=p, size=shape, observed=obs)
2755-
ologp = logpt(ol, 1).eval() * loge
2756-
clogp = logpt(c, 1) * loge
2752+
"ol",
2753+
eta=np.zeros(shape),
2754+
cutpoints=cutpoints,
2755+
observed=obs,
2756+
)
2757+
c = Categorical(
2758+
"c",
2759+
p=p,
2760+
observed=obs,
2761+
)
2762+
ologp = logpt_sum(ol, np.ones_like(obs)).eval() * loge
2763+
clogp = logpt_sum(c, np.ones_like(obs)).eval() * loge
27572764
expected = -np.prod((size,) + shape)
27582765

2759-
assert c.distribution.p.ndim == (len(shape) + 1)
2766+
assert c.owner.inputs[3].ndim == (len(shape) + 1)
27602767
assert np.allclose(clogp, expected)
2761-
assert ol.distribution.p.ndim == (len(shape) + 1)
2768+
assert ol.owner.inputs[3].ndim == (len(shape) + 1)
27622769
assert np.allclose(ologp, expected)
27632770

27642771

pymc3/tests/test_distributions_random.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,16 @@ def constant_rng_fn(self, size, c):
929929
]
930930

931931

932+
class TestOrderedLogistic(BaseTestDistribution):
933+
pymc_dist = pm.OrderedLogistic
934+
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}
935+
expected_rv_op_params = {"p": np.array([0.11920292, 0.38079708, 0.38079708, 0.11920292])}
936+
tests_to_run = [
937+
"check_pymc_params_match_rv_op",
938+
"check_rv_size",
939+
]
940+
941+
932942
class TestScalarParameterSamples(SeededTest):
933943
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
934944
def test_bounded(self):

0 commit comments

Comments
 (0)