Skip to content

Commit 02babc0

Browse files
committed
Added LaTeX repr for all continuous variables
1 parent 2511a58 commit 02babc0

File tree

5 files changed

+231
-0
lines changed

5 files changed

+231
-0
lines changed

pymc3/distributions/continuous.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pymc3.theanof import floatX
1717
from . import transforms
18+
from pymc3.util import get_variable_name
1819

1920
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise, DifferentiableSplineWrapper
2021
from .distribution import Continuous, draw_values, generate_samples, Bound
@@ -152,6 +153,15 @@ def logp(self, value):
152153
return bound(-tt.log(upper - lower),
153154
value >= lower, value <= upper)
154155

156+
def _repr_latex_(self, name=None, dist=None):
157+
if dist is None:
158+
dist = self
159+
lower = dist.lower
160+
upper = dist.upper
161+
return r'$%s \sim \text{Uniform}(\mathit{lower}=%s, \mathit{upper}=%s)$' % (name,
162+
get_variable_name(lower),
163+
get_variable_name(upper))
164+
155165

156166
class Flat(Continuous):
157167
"""
@@ -169,6 +179,11 @@ def random(self, point=None, size=None, repeat=None):
169179
def logp(self, value):
170180
return tt.zeros_like(value)
171181

182+
def _repr_latex_(self, name=None, dist=None):
183+
if dist is None:
184+
dist = self
185+
return r'$%s \sim \text{Flat}()$' % name
186+
172187

173188
class Normal(Continuous):
174189
R"""
@@ -232,6 +247,15 @@ def logp(self, value):
232247
return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
233248
sd > 0)
234249

250+
def _repr_latex_(self, name=None, dist=None):
251+
if dist is None:
252+
dist = self
253+
sd = dist.sd
254+
mu = dist.mu
255+
return r'$%s \sim \text{Normal}(\mathit{mu}=%s, \mathit{sd}=%s)$' % (name,
256+
get_variable_name(mu),
257+
get_variable_name(sd))
258+
235259

236260
class HalfNormal(PositiveContinuous):
237261
R"""
@@ -283,6 +307,12 @@ def logp(self, value):
283307
value >= 0,
284308
tau > 0, sd > 0)
285309

310+
def _repr_latex_(self, name=None, dist=None):
311+
if dist is None:
312+
dist = self
313+
sd = dist.sd
314+
return r'$%s \sim \text{HalfNormal}(\mathit{sd}=%s)$' % (name,
315+
get_variable_name(sd))
286316

287317
class Wald(PositiveContinuous):
288318
R"""
@@ -404,6 +434,17 @@ def logp(self, value):
404434
value > 0, value - alpha > 0,
405435
mu > 0, lam > 0, alpha >= 0)
406436

437+
def _repr_latex_(self, name=None, dist=None):
438+
if dist is None:
439+
dist = self
440+
lam = dist.lam
441+
mu = dist.mu
442+
alpha = dist.alpha
443+
return r'$%s \sim \text{Wald}(\mathit{mu}=%s, \mathit{lam}=%s, \mathit{alpha}=%s)$' % (name,
444+
get_variable_name(mu),
445+
get_variable_name(lam),
446+
get_variable_name(alpha))
447+
407448

408449
class Beta(UnitContinuous):
409450
R"""
@@ -492,6 +533,15 @@ def logp(self, value):
492533
value >= 0, value <= 1,
493534
alpha > 0, beta > 0)
494535

536+
def _repr_latex_(self, name=None, dist=None):
537+
if dist is None:
538+
dist = self
539+
alpha = dist.alpha
540+
beta = dist.beta
541+
return r'$%s \sim \text{Beta}(\mathit{alpha}=%s, \mathit{beta}=%s)$' % (name,
542+
get_variable_name(alpha),
543+
get_variable_name(beta))
544+
495545

496546
class Exponential(PositiveContinuous):
497547
R"""
@@ -534,6 +584,12 @@ def logp(self, value):
534584
lam = self.lam
535585
return bound(tt.log(lam) - lam * value, value > 0, lam > 0)
536586

587+
def _repr_latex_(self, name=None, dist=None):
588+
if dist is None:
589+
dist = self
590+
lam = dist.lam
591+
return r'$%s \sim \text{Exponential}(\mathit{lam}=%s)$' % (name,
592+
get_variable_name(lam))
537593

538594
class Laplace(Continuous):
539595
R"""
@@ -579,6 +635,15 @@ def logp(self, value):
579635

580636
return -tt.log(2 * b) - abs(value - mu) / b
581637

638+
def _repr_latex_(self, name=None, dist=None):
639+
if dist is None:
640+
dist = self
641+
b = dist.b
642+
mu = dist.mu
643+
return r'$%s \sim \text{Laplace}(\mathit{mu}=%s, \mathit{b}=%s)$' % (name,
644+
get_variable_name(mu),
645+
get_variable_name(b))
646+
582647

583648
class Lognormal(PositiveContinuous):
584649
R"""
@@ -643,6 +708,15 @@ def logp(self, value):
643708
- tt.log(value),
644709
tau > 0)
645710

711+
def _repr_latex_(self, name=None, dist=None):
712+
if dist is None:
713+
dist = self
714+
tau = dist.tau
715+
mu = dist.mu
716+
return r'$%s \sim \text{Lognormal}(\mathit{mu}=%s, \mathit{tau}=%s)$' % (name,
717+
get_variable_name(mu),
718+
get_variable_name(tau))
719+
646720

647721
class StudentT(Continuous):
648722
R"""
@@ -707,6 +781,17 @@ def logp(self, value):
707781
- (nu + 1.0) / 2.0 * tt.log1p(lam * (value - mu)**2 / nu),
708782
lam > 0, nu > 0, sd > 0)
709783

784+
def _repr_latex_(self, name=None, dist=None):
785+
if dist is None:
786+
dist = self
787+
nu = dist.nu
788+
mu = dist.mu
789+
lam = dist.lam
790+
return r'$%s \sim \text{StudentT}(\mathit{nu}=%s, \mathit{mu}=%s, \mathit{lam}=%s)$' % (name,
791+
get_variable_name(nu),
792+
get_variable_name(mu),
793+
get_variable_name(lam))
794+
710795

711796
class Pareto(PositiveContinuous):
712797
R"""
@@ -769,6 +854,15 @@ def logp(self, value):
769854
- logpow(value, alpha + 1),
770855
value >= m, alpha > 0, m > 0)
771856

857+
def _repr_latex_(self, name=None, dist=None):
858+
if dist is None:
859+
dist = self
860+
alpha = dist.alpha
861+
m = dist.m
862+
return r'$%s \sim \text{Pareto}(\mathit{alpha}=%s, \mathit{m}=%s)$' % (name,
863+
get_variable_name(alpha),
864+
get_variable_name(m))
865+
772866

773867
class Cauchy(Continuous):
774868
R"""
@@ -821,6 +915,15 @@ def logp(self, value):
821915
- tt.log1p(((value - alpha) / beta)**2),
822916
beta > 0)
823917

918+
def _repr_latex_(self, name=None, dist=None):
919+
if dist is None:
920+
dist = self
921+
alpha = dist.alpha
922+
beta = dist.beta
923+
return r'$%s \sim \text{Cauchy}(\mathit{alpha}=%s, \mathit{beta}=%s)$' % (name,
924+
get_variable_name(alpha),
925+
get_variable_name(beta))
926+
824927

825928
class HalfCauchy(PositiveContinuous):
826929
R"""
@@ -867,6 +970,12 @@ def logp(self, value):
867970
- tt.log1p((value / beta)**2),
868971
value >= 0, beta > 0)
869972

973+
def _repr_latex_(self, name=None, dist=None):
974+
if dist is None:
975+
dist = self
976+
beta = dist.beta
977+
return r'$%s \sim \text{HalfCauchy}(\mathit{beta}=%s)$' % (name,
978+
get_variable_name(beta))
870979

871980
class Gamma(PositiveContinuous):
872981
R"""
@@ -950,6 +1059,15 @@ def logp(self, value):
9501059
alpha > 0,
9511060
beta > 0)
9521061

1062+
def _repr_latex_(self, name=None, dist=None):
1063+
if dist is None:
1064+
dist = self
1065+
beta = dist.beta
1066+
alpha = dist.alpha
1067+
return r'$%s \sim \text{Gamma}(\mathit{alpha}=%s, \mathit{beta}=%s)$' % (name,
1068+
get_variable_name(alpha),
1069+
get_variable_name(beta))
1070+
9531071

9541072
class InverseGamma(PositiveContinuous):
9551073
R"""
@@ -1011,6 +1129,15 @@ def logp(self, value):
10111129
+ logpow(value, -alpha - 1),
10121130
value > 0, alpha > 0, beta > 0)
10131131

1132+
def _repr_latex_(self, name=None, dist=None):
1133+
if dist is None:
1134+
dist = self
1135+
beta = dist.beta
1136+
alpha = dist.alpha
1137+
return r'$%s \sim \text{InverseGamma}(\mathit{alpha}=%s, \mathit{beta}=%s)$' % (name,
1138+
get_variable_name(alpha),
1139+
get_variable_name(beta))
1140+
10141141

10151142
class ChiSquared(Gamma):
10161143
R"""
@@ -1037,6 +1164,13 @@ def __init__(self, nu, *args, **kwargs):
10371164
super(ChiSquared, self).__init__(alpha=nu / 2., beta=0.5,
10381165
*args, **kwargs)
10391166

1167+
def _repr_latex_(self, name=None, dist=None):
1168+
if dist is None:
1169+
dist = self
1170+
nu = dist.nu
1171+
return r'$%s \sim \chi^2(\mathit{nu}=%s)$' % (name,
1172+
get_variable_name(nu))
1173+
10401174

10411175
class Weibull(PositiveContinuous):
10421176
R"""
@@ -1093,6 +1227,15 @@ def logp(self, value):
10931227
- (value / beta)**alpha,
10941228
value >= 0, alpha > 0, beta > 0)
10951229

1230+
def _repr_latex_(self, name=None, dist=None):
1231+
if dist is None:
1232+
dist = self
1233+
beta = dist.beta
1234+
alpha = dist.alpha
1235+
return r'$%s \sim \text{Weibull}(\mathit{alpha}=%s, \mathit{beta}=%s)$' % (name,
1236+
get_variable_name(alpha),
1237+
get_variable_name(beta))
1238+
10961239

10971240
def StudentTpos(*args, **kwargs):
10981241
warnings.warn("StudentTpos has been deprecated. In future, use HalfStudentT instead.",
@@ -1183,6 +1326,17 @@ def logp(self, value):
11831326
- 0.5 * ((value - mu) / sigma)**2)
11841327
return bound(lp, sigma > 0., nu > 0.)
11851328

1329+
def _repr_latex_(self, name=None, dist=None):
1330+
if dist is None:
1331+
dist = self
1332+
sigma = dist.sigma
1333+
mu = dist.mu
1334+
nu = dist.nu
1335+
return r'$%s \sim \text{ExGaussian}(\mathit{mu}=%s, \mathit{sigma}=%s, \mathit{nu}=%s)$' % (name,
1336+
get_variable_name(mu),
1337+
get_variable_name(sigma),
1338+
get_variable_name(nu))
1339+
11861340

11871341
class VonMises(Continuous):
11881342
R"""
@@ -1230,6 +1384,16 @@ def logp(self, value):
12301384
mu = self.mu
12311385
kappa = self.kappa
12321386
return bound(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
1387+
1388+
def _repr_latex_(self, name=None, dist=None):
1389+
if dist is None:
1390+
dist = self
1391+
kappa = dist.kappa
1392+
mu = dist.mu
1393+
return r'$%s \sim \text{VonMises}(\mathit{mu}=%s, \mathit{kappa}=%s)$' % (name,
1394+
get_variable_name(mu),
1395+
get_variable_name(kappa))
1396+
12331397

12341398

12351399
class SkewNormal(Continuous):
@@ -1306,6 +1470,17 @@ def logp(self, value):
13061470
+ tt.log(tau / np.pi / 2.)) / 2.,
13071471
tau > 0, sd > 0)
13081472

1473+
def _repr_latex_(self, name=None, dist=None):
1474+
if dist is None:
1475+
dist = self
1476+
sd = dist.sd
1477+
mu = dist.mu
1478+
alpha = dist.alpha
1479+
return r'$%s \sim \text{Skew-Normal}(\mathit{mu}=%s, \mathit{sd}=%s, \mathit{alpha}=%s)$' % (name,
1480+
get_variable_name(mu),
1481+
get_variable_name(sd),
1482+
get_variable_name(alpha))
1483+
13091484

13101485
class Triangular(Continuous):
13111486
"""
@@ -1348,6 +1523,18 @@ def logp(self, value):
13481523
tt.switch(alltrue_elemwise([c < value, value <= upper]),
13491524
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),np.inf)))
13501525

1526+
def _repr_latex_(self, name=None, dist=None):
1527+
if dist is None:
1528+
dist = self
1529+
lower = dist.lower
1530+
upper = dist.upper
1531+
c = dist.c
1532+
return r'$%s \sim \text{Triangular}(\mathit{c}=%s, \mathit{lower}=%s, \mathit{upper}=%s)$' % (name,
1533+
get_variable_name(c),
1534+
get_variable_name(lower),
1535+
get_variable_name(upper))
1536+
1537+
13511538
class Gumbel(Continuous):
13521539
R"""
13531540
Univariate Gumbel log-likelihood
@@ -1391,6 +1578,16 @@ def logp(self, value):
13911578
scaled = (value - self.mu) / self.beta
13921579
return bound(-scaled - tt.exp(-scaled) - tt.log(self.beta), self.beta > 0)
13931580

1581+
def _repr_latex_(self, name=None, dist=None):
1582+
if dist is None:
1583+
dist = self
1584+
beta = dist.beta
1585+
mu = dist.mu
1586+
return r'$%s \sim \text{Gumbel}(\mathit{mu}=%s, \mathit{beta}=%s)$' % (name,
1587+
get_variable_name(mu),
1588+
get_variable_name(beta))
1589+
1590+
13941591
class Interpolated(Continuous):
13951592
R"""
13961593
Probability distribution defined as a linear interpolation of

pymc3/distributions/distribution.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def getattr_value(self, val):
8585

8686
return val
8787

88+
def _repr_latex_(self, name=None, dist=None):
89+
return None
90+
8891

8992
def TensorType(dtype, shape):
9093
return tt.TensorType(str(dtype), np.atleast_1d(shape) == 1)

pymc3/distributions/transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from . import distribution
66
from ..math import logit, invlogit
77
import numpy as np
8+
from functools import partial
89

910
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
1011
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
@@ -69,6 +70,8 @@ def __init__(self, dist, transform, *args, **kwargs):
6970
b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False))
7071
# force the last dim not broadcastable
7172
self.type = tt.TensorType(v.dtype, b)
73+
74+
self._repr_latex_ = partial(dist._repr_latex_, dist=dist)
7275

7376
def logp(self, x):
7477
return (self.dist.logp(self.transform_used.backward(x)) +

0 commit comments

Comments
 (0)