Skip to content

Commit 9c40f8f

Browse files
committed
Refactor OrderedProbit
1 parent 2a1071d commit 9c40f8f

File tree

3 files changed

+17
-44
lines changed

3 files changed

+17
-44
lines changed

pymc3/distributions/discrete.py

Lines changed: 7 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,12 +1808,14 @@ class OrderedProbit(Categorical):
18081808
18091809
"""
18101810

1811-
def __init__(self, eta, cutpoints, *args, **kwargs):
1811+
rv_op = categorical
18121812

1813-
self.eta = at.as_tensor_variable(floatX(eta))
1814-
self.cutpoints = at.as_tensor_variable(cutpoints)
1813+
@classmethod
1814+
def dist(cls, eta, cutpoints, *args, **kwargs):
1815+
eta = at.as_tensor_variable(floatX(eta))
1816+
cutpoints = at.as_tensor_variable(cutpoints)
18151817

1816-
probits = at.shape_padright(self.eta) - self.cutpoints
1818+
probits = at.shape_padright(eta) - cutpoints
18171819
_log_p = at.concatenate(
18181820
[
18191821
at.shape_padright(normal_lccdf(0, 1, probits[..., 0])),
@@ -1823,44 +1825,6 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
18231825
axis=-1,
18241826
)
18251827
_log_p = at.as_tensor_variable(floatX(_log_p))
1826-
1827-
self._log_p = _log_p
1828-
self.mode = at.argmax(_log_p, axis=-1)
18291828
p = at.exp(_log_p)
18301829

1831-
super().__init__(p=p, *args, **kwargs)
1832-
1833-
def logp(self, value):
1834-
r"""
1835-
Calculate log-probability of Ordered Probit distribution at specified value.
1836-
1837-
Parameters
1838-
----------
1839-
value: numeric
1840-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1841-
values are desired the values must be provided in a numpy array or Aesara tensor
1842-
1843-
Returns
1844-
-------
1845-
TensorVariable
1846-
"""
1847-
logp = self._log_p
1848-
k = self.k
1849-
1850-
# Clip values before using them for indexing
1851-
value_clip = at.clip(value, 0, k - 1)
1852-
1853-
if logp.ndim > 1:
1854-
if logp.ndim > value_clip.ndim:
1855-
value_clip = at.shape_padleft(value_clip, logp.ndim - value_clip.ndim)
1856-
elif logp.ndim < value_clip.ndim:
1857-
logp = at.shape_padleft(logp, value_clip.ndim - logp.ndim)
1858-
pattern = (logp.ndim - 1,) + tuple(range(logp.ndim - 1))
1859-
a = take_along_axis(
1860-
logp.dimshuffle(pattern),
1861-
value_clip,
1862-
)
1863-
else:
1864-
a = logp[value_clip]
1865-
1866-
return bound(a, value >= 0, value <= (k - 1))
1830+
return super().dist(p, **kwargs)

pymc3/tests/test_distributions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2307,7 +2307,6 @@ def test_orderedlogistic(self, n):
23072307
)
23082308

23092309
@pytest.mark.parametrize("n", [2, 3, 4])
2310-
@pytest.mark.xfail(reason="Distribution not refactored yet")
23112310
def test_orderedprobit(self, n):
23122311
self.check_logp(
23132312
OrderedProbit,

pymc3/tests/test_distributions_random.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,16 @@ class TestOrderedLogistic(BaseTestDistribution):
939939
]
940940

941941

942+
class TestOrderedProbit(BaseTestDistribution):
943+
pymc_dist = pm.OrderedProbit
944+
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}
945+
expected_rv_op_params = {"p": np.array([0.02275013, 0.47724987, 0.47724987, 0.02275013])}
946+
tests_to_run = [
947+
"check_pymc_params_match_rv_op",
948+
"check_rv_size",
949+
]
950+
951+
942952
class TestScalarParameterSamples(SeededTest):
943953
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
944954
def test_bounded(self):

0 commit comments

Comments
 (0)