Skip to content

Commit e8f75fd

Browse files
committed
Add numerically safer ordered probit distribution.
1 parent c5c9a14 commit e8f75fd

File tree

4 files changed

+203
-2
lines changed

4 files changed

+203
-2
lines changed

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .discrete import Geometric
6666
from .discrete import Categorical
6767
from .discrete import OrderedLogistic
68+
from .discrete import OrderedProbit
6869

6970
from .distribution import DensityDist
7071
from .distribution import Distribution
@@ -143,6 +144,7 @@
143144
"Geometric",
144145
"Categorical",
145146
"OrderedLogistic",
147+
"OrderedProbit",
146148
"DensityDist",
147149
"Distribution",
148150
"Continuous",

pymc3/distributions/discrete.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@
1717
from scipy import stats
1818
import warnings
1919

20-
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
20+
from .dist_math import (
21+
bound,
22+
factln,
23+
binomln,
24+
betaln,
25+
logpow,
26+
random_choice,
27+
normal_lcdf,
28+
normal_lccdf,
29+
log_diff_normal_cdf,
30+
)
2131
from .distribution import Discrete, draw_values, generate_samples
2232
from .shape_utils import broadcast_distribution_samples
2333
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
@@ -1517,3 +1527,131 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
15171527
p = p_cum[..., 1:] - p_cum[..., :-1]
15181528

15191529
super().__init__(p=p, *args, **kwargs)
1530+
1531+
1532+
class OrderedProbit(Categorical):
1533+
R"""
1534+
Ordered Probit log-likelihood.
1535+
1536+
Useful for regression on ordinal data values whose values range
1537+
from 1 to K as a function of some predictor, :math:`\eta`. The
1538+
cutpoints, :math:`c`, separate which ranges of :math:`\eta` are
1539+
mapped to which of the K observed dependent variables. The number
1540+
of cutpoints is K - 1. It is recommended that the cutpoints are
1541+
constrained to be ordered.
1542+
1543+
In order to stabilize the computation, log-likelihood is computed
1544+
in log space using the scaled error function `erfcx`.
1545+
1546+
.. math::
1547+
1548+
f(k \mid \eta, c) = \left\{
1549+
\begin{array}{l}
1550+
1 - \text{normal_cdf}(0, \sigma, \eta - c_1)
1551+
\,, \text{if } k = 0 \\
1552+
\text{normal_cdf}(0, \sigma, \eta - c_{k - 1}) -
1553+
\text{normal_cdf}(0, \sigma, \eta - c_{k})
1554+
\,, \text{if } 0 < k < K \\
1555+
\text{normal_cdf}(0, \sigma, \eta - c_{K - 1})
1556+
\,, \text{if } k = K \\
1557+
\end{array}
1558+
\right.
1559+
1560+
Parameters
1561+
----------
1562+
eta : float
1563+
The predictor.
1564+
c : array
1565+
The length K - 1 array of cutpoints which break :math:`\eta` into
1566+
ranges. Do not explicitly set the first and last elements of
1567+
:math:`c` to negative and positive infinity.
1568+
1569+
sigma: float
1570+
The standard deviation of probit function.
1571+
Example
1572+
--------
1573+
.. code:: python
1574+
1575+
# Generate data for a simple 1 dimensional example problem
1576+
n1_c = 300; n2_c = 300; n3_c = 300
1577+
cluster1 = np.random.randn(n1_c) + -1
1578+
cluster2 = np.random.randn(n2_c) + 0
1579+
cluster3 = np.random.randn(n3_c) + 2
1580+
1581+
x = np.concatenate((cluster1, cluster2, cluster3))
1582+
y = np.concatenate((1*np.ones(n1_c),
1583+
2*np.ones(n2_c),
1584+
3*np.ones(n3_c))) - 1
1585+
1586+
# Ordered logistic regression
1587+
with pm.Model() as model:
1588+
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
1589+
transform=pm.distributions.transforms.ordered)
1590+
y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
1591+
tr = pm.sample(1000)
1592+
1593+
# Plot the results
1594+
plt.hist(cluster1, 30, alpha=0.5);
1595+
plt.hist(cluster2, 30, alpha=0.5);
1596+
plt.hist(cluster3, 30, alpha=0.5);
1597+
plt.hist(tr["cutpoints"][:,0], 80, alpha=0.2, color='k');
1598+
plt.hist(tr["cutpoints"][:,1], 80, alpha=0.2, color='k');
1599+
1600+
"""
1601+
1602+
def __init__(self, eta, cutpoints, *args, **kwargs):
1603+
1604+
self.eta = tt.as_tensor_variable(floatX(eta))
1605+
self.cutpoints = tt.as_tensor_variable(cutpoints)
1606+
1607+
probits = tt.shape_padright(self.eta) - self.cutpoints
1608+
_log_p = tt.concatenate(
1609+
[
1610+
tt.shape_padright(normal_lccdf(0, 1, probits[..., 0])),
1611+
log_diff_normal_cdf(0, 1, probits[..., :-1], probits[..., 1:]),
1612+
tt.shape_padright(normal_lcdf(0, 1, probits[..., -1])),
1613+
],
1614+
axis=-1,
1615+
)
1616+
_log_p = tt.as_tensor_variable(floatX(_log_p))
1617+
1618+
self._log_p = _log_p
1619+
self.mode = tt.argmax(_log_p, axis=-1)
1620+
p = tt.exp(_log_p)
1621+
1622+
super().__init__(p=p, *args, **kwargs)
1623+
1624+
def logp(self, value):
1625+
r"""
1626+
Calculate log-probability of Ordered Probit distribution at specified value.
1627+
1628+
Parameters
1629+
----------
1630+
value: numeric
1631+
Value(s) for which log-probability is calculated. If the log probabilities for multiple
1632+
values are desired the values must be provided in a numpy array or theano tensor
1633+
1634+
Returns
1635+
-------
1636+
TensorVariable
1637+
"""
1638+
logp = self._log_p
1639+
k = self.k
1640+
1641+
# Clip values before using them for indexing
1642+
value_clip = tt.clip(value, 0, k - 1)
1643+
1644+
if logp.ndim > 1:
1645+
if logp.ndim > value_clip.ndim:
1646+
value_clip = tt.shape_padleft(value_clip, logp.ndim - value_clip.ndim)
1647+
elif logp.ndim < value_clip.ndim:
1648+
logp = tt.shape_padleft(logp, value_clip.ndim - logp.ndim)
1649+
pattern = (logp.ndim - 1,) + tuple(range(logp.ndim - 1))
1650+
a = take_along_axis(
1651+
logp.dimshuffle(pattern),
1652+
value_clip,
1653+
)
1654+
else:
1655+
a = logp[value_clip]
1656+
1657+
return bound(a, value >= 0, value <= (k - 1))

pymc3/distributions/dist_math.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,46 @@ def normal_lccdf(mu, sigma, x):
133133
)
134134

135135

136+
def log_diff_normal_cdf(mu, sigma, x, y):
137+
"""
138+
Compute :math:`log(\Phi(\frac{x - \mu}{\sigma}) - \Phi(\frac{y - \mu}{\sigma}))` safely in log space.
139+
140+
Parameters
141+
----------
142+
mu: float
143+
mean
144+
sigma: float
145+
std
146+
147+
x: float
148+
149+
y: float
150+
must be strictly less than x.
151+
152+
Returns
153+
-------
154+
log (\Phi(x) - \Phi(y))
155+
156+
"""
157+
x = (x - mu) / sigma / tt.sqrt(2.0)
158+
y = (y - mu) / sigma / tt.sqrt(2.0)
159+
160+
# To stabilize the computation, consider these three regions:
161+
# 1) x > y > 0 => Use erf(x) = 1 - e^{-x^2} erfcx(x) and erf(y) =1 - e^{-y^2} erfcx(y)
162+
# 2) 0 > x > 0 => Use erf(x) = e^{-x^2} erfcx(-x) and erf(y) = e^{-y^2} erfcx(-y)
163+
# 3) x > 0 > y => Naive formula log( (erf(x) - erf(y)) / 2 ) works fine.
164+
return tt.log(0.5) + tt.switch(
165+
tt.gt(y, 0),
166+
-tt.square(y) + tt.log(tt.erfcx(y) - tt.exp(tt.square(y) - tt.square(x)) * tt.erfcx(x)),
167+
tt.switch(
168+
tt.lt(x, 0), # 0 > x > y
169+
-tt.square(x)
170+
+ tt.log(tt.erfcx(-x) - tt.exp(tt.square(x) - tt.square(y)) * tt.erfcx(-y)),
171+
tt.log(tt.erf(x) - tt.erf(y)), # x >0 > y
172+
),
173+
)
174+
175+
136176
def sigma2rho(sigma):
137177
"""
138178
`sigma -> rho` theano converter

pymc3/tests/test_distributions.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
Gumbel,
6767
Logistic,
6868
OrderedLogistic,
69+
OrderedProbit,
6970
LogitNormal,
7071
Interpolated,
7172
ZeroInflatedBinomial,
@@ -89,7 +90,7 @@
8990
from scipy import integrate
9091
import scipy.stats.distributions as sp
9192
import scipy.stats
92-
from scipy.special import logit
93+
from scipy.special import logit, erf, erfcx
9394
import theano
9495
import theano.tensor as tt
9596
from ..math import kronecker
@@ -429,6 +430,17 @@ def orderedlogistic_logpdf(value, eta, cutpoints):
429430
return np.where(np.all(ps >= 0), np.log(p), -np.inf)
430431

431432

433+
def invprobit(x):
434+
return (erf(x / np.sqrt(2)) + 1) / 2
435+
436+
437+
def orderedprobit_logpdf(value, eta, cutpoints):
438+
c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
439+
ps = np.array([invprobit(eta - cc) - invprobit(eta - cc1) for cc, cc1 in zip(c[:-1], c[1:])])
440+
p = ps[value]
441+
return np.where(np.all(ps >= 0), np.log(p), -np.inf)
442+
443+
432444
class Simplex:
433445
def __init__(self, n):
434446
self.vals = list(simplex_values(n))
@@ -1516,6 +1528,15 @@ def test_orderedlogistic(self, n):
15161528
lambda value, eta, cutpoints: orderedlogistic_logpdf(value, eta, cutpoints),
15171529
)
15181530

1531+
@pytest.mark.parametrize("n", [2, 3, 4])
1532+
def test_orderedprobit(self, n):
1533+
self.pymc3_matches_scipy(
1534+
OrderedProbit,
1535+
Domain(range(n), "int64"),
1536+
{"eta": R, "cutpoints": SortedVector(n - 1)},
1537+
lambda value, eta, cutpoints: orderedprobit_logpdf(value, eta, cutpoints),
1538+
)
1539+
15191540
def test_densitydist(self):
15201541
def logp(x):
15211542
return -log(2 * 0.5) - abs(x - 0.5) / 0.5

0 commit comments

Comments
 (0)