Skip to content

Commit 67de240

Browse files
fonnesbecktwiecki
authored andcommitted
Add LaTeX repr for distributions (#2201)
* Added LaTeX repr for all continuous variables * Fixed indentation bug in StudentT _repr_latex_ * Replaced old-style with new-style formatting * Avoid storing bound methods as variables to prevent pickling problems * Cleaned up whitespace * Added LaTeX repr to discrete, mixture, MV and time series distributions * Removed whitespace from continuous.py * Removed whitespace from distributions files * Another attempt to remove whitespace * Trailing white spaces removed for sure this time
1 parent 565e26f commit 67de240

File tree

9 files changed

+623
-156
lines changed

9 files changed

+623
-156
lines changed

pymc3/distributions/continuous.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pymc3.theanof import floatX
1717
from . import transforms
18+
from pymc3.util import get_variable_name
1819

1920
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise, DifferentiableSplineWrapper
2021
from .distribution import Continuous, draw_values, generate_samples, Bound
@@ -152,6 +153,15 @@ def logp(self, value):
152153
return bound(-tt.log(upper - lower),
153154
value >= lower, value <= upper)
154155

156+
def _repr_latex_(self, name=None, dist=None):
157+
if dist is None:
158+
dist = self
159+
lower = dist.lower
160+
upper = dist.upper
161+
return r'${} \sim \text{{Uniform}}(\mathit{{lower}}={}, \mathit{{upper}}={})$'.format(name,
162+
get_variable_name(lower),
163+
get_variable_name(upper))
164+
155165

156166
class Flat(Continuous):
157167
"""
@@ -169,6 +179,11 @@ def random(self, point=None, size=None, repeat=None):
169179
def logp(self, value):
170180
return tt.zeros_like(value)
171181

182+
def _repr_latex_(self, name=None, dist=None):
183+
if dist is None:
184+
dist = self
185+
return r'${} \sim \text{{Flat}()$'
186+
172187

173188
class Normal(Continuous):
174189
R"""
@@ -232,6 +247,15 @@ def logp(self, value):
232247
return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
233248
sd > 0)
234249

250+
def _repr_latex_(self, name=None, dist=None):
251+
if dist is None:
252+
dist = self
253+
sd = dist.sd
254+
mu = dist.mu
255+
return r'${} \sim \text{{Normal}}(\mathit{{mu}}={}, \mathit{{sd}}={})$'.format(name,
256+
get_variable_name(mu),
257+
get_variable_name(sd))
258+
235259

236260
class HalfNormal(PositiveContinuous):
237261
R"""
@@ -283,6 +307,12 @@ def logp(self, value):
283307
value >= 0,
284308
tau > 0, sd > 0)
285309

310+
def _repr_latex_(self, name=None, dist=None):
311+
if dist is None:
312+
dist = self
313+
sd = dist.sd
314+
return r'${} \sim \text{{HalfNormal}}(\mathit{{sd}}={})$'.format(name,
315+
get_variable_name(sd))
286316

287317
class Wald(PositiveContinuous):
288318
R"""
@@ -404,6 +434,17 @@ def logp(self, value):
404434
value > 0, value - alpha > 0,
405435
mu > 0, lam > 0, alpha >= 0)
406436

437+
def _repr_latex_(self, name=None, dist=None):
438+
if dist is None:
439+
dist = self
440+
lam = dist.lam
441+
mu = dist.mu
442+
alpha = dist.alpha
443+
return r'${} \sim \text{{Wald}}(\mathit{{mu}}={}, \mathit{{lam}}={}, \mathit{{alpha}}={})$'.format(name,
444+
get_variable_name(mu),
445+
get_variable_name(lam),
446+
get_variable_name(alpha))
447+
407448

408449
class Beta(UnitContinuous):
409450
R"""
@@ -492,6 +533,15 @@ def logp(self, value):
492533
value >= 0, value <= 1,
493534
alpha > 0, beta > 0)
494535

536+
def _repr_latex_(self, name=None, dist=None):
537+
if dist is None:
538+
dist = self
539+
alpha = dist.alpha
540+
beta = dist.beta
541+
return r'${} \sim \text{{Beta}}(\mathit{{alpha}}={}, \mathit{{alpha}}={})$'.format(name,
542+
get_variable_name(alpha),
543+
get_variable_name(beta))
544+
495545

496546
class Exponential(PositiveContinuous):
497547
R"""
@@ -534,6 +584,12 @@ def logp(self, value):
534584
lam = self.lam
535585
return bound(tt.log(lam) - lam * value, value > 0, lam > 0)
536586

587+
def _repr_latex_(self, name=None, dist=None):
588+
if dist is None:
589+
dist = self
590+
lam = dist.lam
591+
return r'${} \sim \text{{Exponential}}(\mathit{{lam}}={})$'.format(name,
592+
get_variable_name(lam))
537593

538594
class Laplace(Continuous):
539595
R"""
@@ -579,6 +635,15 @@ def logp(self, value):
579635

580636
return -tt.log(2 * b) - abs(value - mu) / b
581637

638+
def _repr_latex_(self, name=None, dist=None):
639+
if dist is None:
640+
dist = self
641+
b = dist.b
642+
mu = dist.mu
643+
return r'${} \sim \text{{Laplace}}(\mathit{{mu}}={}, \mathit{{b}}={})$'.format(name,
644+
get_variable_name(mu),
645+
get_variable_name(b))
646+
582647

583648
class Lognormal(PositiveContinuous):
584649
R"""
@@ -643,6 +708,15 @@ def logp(self, value):
643708
- tt.log(value),
644709
tau > 0)
645710

711+
def _repr_latex_(self, name=None, dist=None):
712+
if dist is None:
713+
dist = self
714+
tau = dist.tau
715+
mu = dist.mu
716+
return r'${} \sim \text{{Lognormal}}(\mathit{{mu}}={}, \mathit{{tau}}={})$'.format(name,
717+
get_variable_name(mu),
718+
get_variable_name(tau))
719+
646720

647721
class StudentT(Continuous):
648722
R"""
@@ -707,6 +781,17 @@ def logp(self, value):
707781
- (nu + 1.0) / 2.0 * tt.log1p(lam * (value - mu)**2 / nu),
708782
lam > 0, nu > 0, sd > 0)
709783

784+
def _repr_latex_(self, name=None, dist=None):
785+
if dist is None:
786+
dist = self
787+
nu = dist.nu
788+
mu = dist.mu
789+
lam = dist.lam
790+
return r'${} \sim \text{{StudentT}}(\mathit{{nu}}={}, \mathit{{mu}}={}, \mathit{{lam}}={})$'.format(name,
791+
get_variable_name(nu),
792+
get_variable_name(mu),
793+
get_variable_name(lam))
794+
710795

711796
class Pareto(PositiveContinuous):
712797
R"""
@@ -769,6 +854,15 @@ def logp(self, value):
769854
- logpow(value, alpha + 1),
770855
value >= m, alpha > 0, m > 0)
771856

857+
def _repr_latex_(self, name=None, dist=None):
858+
if dist is None:
859+
dist = self
860+
alpha = dist.alpha
861+
m = dist.m
862+
return r'${} \sim \text{{Pareto}}(\mathit{{alpha}}={}, \mathit{{m}}={})$'.format(name,
863+
get_variable_name(alpha),
864+
get_variable_name(m))
865+
772866

773867
class Cauchy(Continuous):
774868
R"""
@@ -821,6 +915,15 @@ def logp(self, value):
821915
- tt.log1p(((value - alpha) / beta)**2),
822916
beta > 0)
823917

918+
def _repr_latex_(self, name=None, dist=None):
919+
if dist is None:
920+
dist = self
921+
alpha = dist.alpha
922+
beta = dist.beta
923+
return r'${} \sim \text{{Cauchy}}(\mathit{{alpha}}={}, \mathit{{beta}}={})$'.format(name,
924+
get_variable_name(alpha),
925+
get_variable_name(beta))
926+
824927

825928
class HalfCauchy(PositiveContinuous):
826929
R"""
@@ -867,6 +970,12 @@ def logp(self, value):
867970
- tt.log1p((value / beta)**2),
868971
value >= 0, beta > 0)
869972

973+
def _repr_latex_(self, name=None, dist=None):
974+
if dist is None:
975+
dist = self
976+
beta = dist.beta
977+
return r'${} \sim \text{{HalfCauchy}}(\mathit{{beta}}={})$'.format(name,
978+
get_variable_name(beta))
870979

871980
class Gamma(PositiveContinuous):
872981
R"""
@@ -950,6 +1059,15 @@ def logp(self, value):
9501059
alpha > 0,
9511060
beta > 0)
9521061

1062+
def _repr_latex_(self, name=None, dist=None):
1063+
if dist is None:
1064+
dist = self
1065+
beta = dist.beta
1066+
alpha = dist.alpha
1067+
return r'${} \sim \text{{Gamma}}(\mathit{{alpha}}={}, \mathit{{beta}}={})$'.format(name,
1068+
get_variable_name(alpha),
1069+
get_variable_name(beta))
1070+
9531071

9541072
class InverseGamma(PositiveContinuous):
9551073
R"""
@@ -1011,6 +1129,15 @@ def logp(self, value):
10111129
+ logpow(value, -alpha - 1),
10121130
value > 0, alpha > 0, beta > 0)
10131131

1132+
def _repr_latex_(self, name=None, dist=None):
1133+
if dist is None:
1134+
dist = self
1135+
beta = dist.beta
1136+
alpha = dist.alpha
1137+
return r'${} \sim \text{{InverseGamma}}(\mathit{{alpha}}={}, \mathit{{beta}}={})$'.format(name,
1138+
get_variable_name(alpha),
1139+
get_variable_name(beta))
1140+
10141141

10151142
class ChiSquared(Gamma):
10161143
R"""
@@ -1037,6 +1164,13 @@ def __init__(self, nu, *args, **kwargs):
10371164
super(ChiSquared, self).__init__(alpha=nu / 2., beta=0.5,
10381165
*args, **kwargs)
10391166

1167+
def _repr_latex_(self, name=None, dist=None):
1168+
if dist is None:
1169+
dist = self
1170+
nu = dist.nu
1171+
return r'${} \sim \Chi^2(\mathit{{nu}}={})$'.format(name,
1172+
get_variable_name(nu))
1173+
10401174

10411175
class Weibull(PositiveContinuous):
10421176
R"""
@@ -1093,6 +1227,15 @@ def logp(self, value):
10931227
- (value / beta)**alpha,
10941228
value >= 0, alpha > 0, beta > 0)
10951229

1230+
def _repr_latex_(self, name=None, dist=None):
1231+
if dist is None:
1232+
dist = self
1233+
beta = dist.beta
1234+
alpha = dist.alpha
1235+
return r'${} \sim \text{{Weibull}}(\mathit{{alpha}}={}, \mathit{{beta}}={})$'.format(name,
1236+
get_variable_name(alpha),
1237+
get_variable_name(beta))
1238+
10961239

10971240
def StudentTpos(*args, **kwargs):
10981241
warnings.warn("StudentTpos has been deprecated. In future, use HalfStudentT instead.",
@@ -1183,6 +1326,17 @@ def logp(self, value):
11831326
- 0.5 * ((value - mu) / sigma)**2)
11841327
return bound(lp, sigma > 0., nu > 0.)
11851328

1329+
def _repr_latex_(self, name=None, dist=None):
1330+
if dist is None:
1331+
dist = self
1332+
sigma = dist.sigma
1333+
mu = dist.mu
1334+
nu = dist.nu
1335+
return r'${} \sim \text{{ExGaussian}}(\mathit{{mu}}={}, \mathit{{sigma}}={}, \mathit{{nu}}={})$'.format(name,
1336+
get_variable_name(mu),
1337+
get_variable_name(sigma),
1338+
get_variable_name(nu))
1339+
11861340

11871341
class VonMises(Continuous):
11881342
R"""
@@ -1231,6 +1385,16 @@ def logp(self, value):
12311385
kappa = self.kappa
12321386
return bound(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
12331387

1388+
def _repr_latex_(self, name=None, dist=None):
1389+
if dist is None:
1390+
dist = self
1391+
kappa = dist.kappa
1392+
mu = dist.mu
1393+
return r'${} \sim \text{{VonMises}}(\mathit{{mu}}={}, \mathit{{kappa}}={})$'.format(name,
1394+
get_variable_name(mu),
1395+
get_variable_name(kappa))
1396+
1397+
12341398

12351399
class SkewNormal(Continuous):
12361400
R"""
@@ -1306,6 +1470,17 @@ def logp(self, value):
13061470
+ tt.log(tau / np.pi / 2.)) / 2.,
13071471
tau > 0, sd > 0)
13081472

1473+
def _repr_latex_(self, name=None, dist=None):
1474+
if dist is None:
1475+
dist = self
1476+
sd = dist.sd
1477+
mu = dist.mu
1478+
alpha = dist.alpha
1479+
return r'${} \sim \text{{Skew-Normal}}(\mathit{{mu}}={}, \mathit{{sd}}={}, \mathit{{alpha}}={})$'.format(name,
1480+
get_variable_name(mu),
1481+
get_variable_name(sd),
1482+
get_variable_name(alpha))
1483+
13091484

13101485
class Triangular(Continuous):
13111486
"""
@@ -1348,6 +1523,18 @@ def logp(self, value):
13481523
tt.switch(alltrue_elemwise([c < value, value <= upper]),
13491524
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),np.inf)))
13501525

1526+
def _repr_latex_(self, name=None, dist=None):
1527+
if dist is None:
1528+
dist = self
1529+
lower = dist.lower
1530+
upper = dist.upper
1531+
c = dist.c
1532+
return r'${} \sim \text{{Triangular}}(\mathit{{c}}={}, \mathit{{lower}}={}, \mathit{{upper}}={})$'.format(name,
1533+
get_variable_name(c),
1534+
get_variable_name(lower),
1535+
get_variable_name(upper))
1536+
1537+
13511538
class Gumbel(Continuous):
13521539
R"""
13531540
Univariate Gumbel log-likelihood
@@ -1391,6 +1578,16 @@ def logp(self, value):
13911578
scaled = (value - self.mu) / self.beta
13921579
return bound(-scaled - tt.exp(-scaled) - tt.log(self.beta), self.beta > 0)
13931580

1581+
def _repr_latex_(self, name=None, dist=None):
1582+
if dist is None:
1583+
dist = self
1584+
beta = dist.beta
1585+
mu = dist.mu
1586+
return r'${} \sim \text{{Gumbel}}(\mathit{{mu}}={}, \mathit{{beta}}={})$'.format(name,
1587+
get_variable_name(mu),
1588+
get_variable_name(beta))
1589+
1590+
13941591
class Interpolated(Continuous):
13951592
R"""
13961593
Univariate probability distribution defined as a linear interpolation

0 commit comments

Comments
 (0)