Skip to content

Commit c1323b1

Browse files
fonnesbecktwiecki
authored andcommitted
Converted Gumbel distribution to v4
1 parent c4ccbee commit c1323b1

File tree

1 file changed

+45
-120
lines changed

1 file changed

+45
-120
lines changed

pymc3/distributions/continuous.py

Lines changed: 45 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@
2020

2121
import aesara.tensor as at
2222
import numpy as np
23+
from typing import Union
2324

2425
from aesara.assert_op import Assert
26+
from aesara.tensor.var import TensorVariable
27+
from aesara.tensor.random.op import RandomVariable
2528
from aesara.tensor.random.basic import (
2629
BetaRV,
2730
cauchy,
2831
exponential,
2932
gamma,
33+
gumbel,
3034
halfcauchy,
3135
halfnormal,
3236
invgamma,
@@ -258,11 +262,7 @@ def logcdf(value, lower, upper):
258262
return at.switch(
259263
at.lt(value, lower) | at.lt(upper, lower),
260264
-np.inf,
261-
at.switch(
262-
at.lt(value, upper),
263-
at.log(value - lower) - at.log(upper - lower),
264-
0,
265-
),
265+
at.switch(at.lt(value, upper), at.log(value - lower) - at.log(upper - lower), 0,),
266266
)
267267

268268

@@ -496,10 +496,7 @@ def logcdf(value, mu, sigma):
496496
-------
497497
TensorVariable
498498
"""
499-
return bound(
500-
normal_lcdf(mu, sigma, value),
501-
0 < sigma,
502-
)
499+
return bound(normal_lcdf(mu, sigma, value), 0 < sigma,)
503500

504501

505502
class TruncatedNormal(BoundedContinuous):
@@ -833,11 +830,7 @@ def logcdf(value, loc, sigma):
833830
TensorVariable
834831
"""
835832
z = zvalue(value, mu=loc, sigma=sigma)
836-
return bound(
837-
at.log1p(-at.erfc(z / at.sqrt(2.0))),
838-
loc <= value,
839-
0 < sigma,
840-
)
833+
return bound(at.log1p(-at.erfc(z / at.sqrt(2.0))), loc <= value, 0 < sigma,)
841834

842835
def _distr_parameters_for_repr(self):
843836
return ["sigma"]
@@ -1053,11 +1046,7 @@ def logcdf(self, value):
10531046
b = 2.0 / l + normal_lcdf(0, 1, -(q + 1.0) / r)
10541047

10551048
return bound(
1056-
at.switch(
1057-
at.lt(value, np.inf),
1058-
a + log1pexp(b - a),
1059-
0,
1060-
),
1049+
at.switch(at.lt(value, np.inf), a + log1pexp(b - a), 0,),
10611050
0 < value,
10621051
0 < mu,
10631052
0 < lam,
@@ -1219,11 +1208,7 @@ def logcdf(value, alpha, beta):
12191208
)
12201209

12211210
return bound(
1222-
at.switch(
1223-
at.lt(value, 1),
1224-
at.log(incomplete_beta(alpha, beta, value)),
1225-
0,
1226-
),
1211+
at.switch(at.lt(value, 1), at.log(incomplete_beta(alpha, beta, value)), 0,),
12271212
0 <= value,
12281213
0 < alpha,
12291214
0 < beta,
@@ -1418,11 +1403,7 @@ def logcdf(value, lam):
14181403
TensorVariable
14191404
"""
14201405
a = lam * value
1421-
return bound(
1422-
log1mexp(a),
1423-
0 <= value,
1424-
0 <= lam,
1425-
)
1406+
return bound(log1mexp(a), 0 <= value, 0 <= lam,)
14261407

14271408

14281409
class Laplace(Continuous):
@@ -1538,11 +1519,7 @@ def logcdf(self, value):
15381519
at.switch(
15391520
at.le(value, a),
15401521
at.log(0.5) + y,
1541-
at.switch(
1542-
at.gt(y, 1),
1543-
at.log1p(-0.5 * at.exp(-y)),
1544-
at.log(1 - 0.5 * at.exp(-y)),
1545-
),
1522+
at.switch(at.gt(y, 1), at.log1p(-0.5 * at.exp(-y)), at.log(1 - 0.5 * at.exp(-y)),),
15461523
),
15471524
0 < b,
15481525
)
@@ -1799,11 +1776,7 @@ def logcdf(self, value):
17991776
sigma = self.sigma
18001777
tau = self.tau
18011778

1802-
return bound(
1803-
normal_lcdf(mu, sigma, at.log(value)),
1804-
0 < value,
1805-
0 < tau,
1806-
)
1779+
return bound(normal_lcdf(mu, sigma, at.log(value)), 0 < value, 0 < tau,)
18071780

18081781

18091782
class StudentT(Continuous):
@@ -1967,12 +1940,7 @@ def logcdf(self, value):
19671940
sqrt_t2_nu = at.sqrt(t ** 2 + nu)
19681941
x = (t + sqrt_t2_nu) / (2.0 * sqrt_t2_nu)
19691942

1970-
return bound(
1971-
at.log(incomplete_beta(nu / 2.0, nu / 2.0, x)),
1972-
0 < nu,
1973-
0 < sigma,
1974-
0 < lam,
1975-
)
1943+
return bound(at.log(incomplete_beta(nu / 2.0, nu / 2.0, x)), 0 < nu, 0 < sigma, 0 < lam,)
19761944

19771945

19781946
class Pareto(Continuous):
@@ -2107,11 +2075,7 @@ def logcdf(self, value):
21072075
alpha = self.alpha
21082076
arg = (m / value) ** alpha
21092077
return bound(
2110-
at.switch(
2111-
at.le(arg, 1e-5),
2112-
at.log1p(-arg),
2113-
at.log(1 - arg),
2114-
),
2078+
at.switch(at.le(arg, 1e-5), at.log1p(-arg), at.log(1 - arg),),
21152079
m <= value,
21162080
0 < alpha,
21172081
0 < m,
@@ -2209,10 +2173,7 @@ def logcdf(value, alpha, beta):
22092173
-------
22102174
TensorVariable
22112175
"""
2212-
return bound(
2213-
at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi),
2214-
0 < beta,
2215-
)
2176+
return bound(at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi), 0 < beta,)
22162177

22172178

22182179
class HalfCauchy(PositiveContinuous):
@@ -2296,11 +2257,7 @@ def logcdf(value, loc, beta):
22962257
-------
22972258
TensorVariable
22982259
"""
2299-
return bound(
2300-
at.log(2 * at.arctan((value - loc) / beta) / np.pi),
2301-
loc <= value,
2302-
0 < beta,
2303-
)
2260+
return bound(at.log(2 * at.arctan((value - loc) / beta) / np.pi), loc <= value, 0 < beta,)
23042261

23052262

23062263
class Gamma(PositiveContinuous):
@@ -2768,12 +2725,7 @@ def logcdf(self, value):
27682725
alpha = self.alpha
27692726
beta = self.beta
27702727
a = (value / beta) ** alpha
2771-
return bound(
2772-
log1mexp(a),
2773-
0 <= value,
2774-
0 < alpha,
2775-
0 < beta,
2776-
)
2728+
return bound(log1mexp(a), 0 <= value, 0 < alpha, 0 < beta,)
27772729

27782730

27792731
class HalfStudentT(PositiveContinuous):
@@ -3532,43 +3484,29 @@ class Gumbel(Continuous):
35323484
beta: float
35333485
Scale parameter (beta > 0).
35343486
"""
3487+
rv_op = gumbel
35353488

3536-
def __init__(self, mu=0, beta=1.0, **kwargs):
3537-
self.mu = at.as_tensor_variable(floatX(mu))
3538-
self.beta = at.as_tensor_variable(floatX(beta))
3539-
3540-
assert_negative_support(beta, "beta", "Gumbel")
3541-
3542-
self.mean = self.mu + self.beta * np.euler_gamma
3543-
self.median = self.mu - self.beta * at.log(at.log(2))
3544-
self.mode = self.mu
3545-
self.variance = (np.pi ** 2 / 6.0) * self.beta ** 2
3489+
@classmethod
3490+
def dist(
3491+
cls, mu: float = None, beta: float = None, no_assert: bool = False, **kwargs
3492+
) -> RandomVariable:
35463493

3547-
super().__init__(**kwargs)
3494+
mu = at.as_tensor_variable(floatX(mu))
3495+
beta = at.as_tensor_variable(floatX(beta))
35483496

3549-
def random(self, point=None, size=None):
3550-
"""
3551-
Draw random values from Gumbel distribution.
3497+
if not no_assert:
3498+
assert_negative_support(beta, "beta", "Gumbel")
35523499

3553-
Parameters
3554-
----------
3555-
point: dict, optional
3556-
Dict of variable values on which random values are to be
3557-
conditioned (uses default point if not specified).
3558-
size: int, optional
3559-
Desired size of random sample (returns one sample if not
3560-
specified).
3500+
return super().dist([mu, beta], **kwargs)
35613501

3562-
Returns
3563-
-------
3564-
array
3565-
"""
3566-
# mu, sigma = draw_values([self.mu, self.beta], point=point, size=size)
3567-
# return generate_samples(
3568-
# stats.gumbel_r.rvs, loc=mu, scale=sigma, dist_shape=self.shape, size=size
3569-
# )
3502+
def _distr_parameters_for_repr(self):
3503+
return ["mu", "beta"]
35703504

3571-
def logp(self, value):
3505+
def logp(
3506+
value: Union[float, np.ndarray, TensorVariable],
3507+
mu: Union[float, np.ndarray, TensorVariable],
3508+
beta: Union[float, np.ndarray, TensorVariable],
3509+
) -> TensorVariable:
35723510
"""
35733511
Calculate log-probability of Gumbel distribution at specified value.
35743512
@@ -3582,15 +3520,14 @@ def logp(self, value):
35823520
-------
35833521
TensorVariable
35843522
"""
3585-
mu = self.mu
3586-
beta = self.beta
35873523
scaled = (value - mu) / beta
3588-
return bound(
3589-
-scaled - at.exp(-scaled) - at.log(self.beta),
3590-
0 < beta,
3591-
)
3524+
return bound(-scaled - at.exp(-scaled) - at.log(beta), 0 < beta,)
35923525

3593-
def logcdf(self, value):
3526+
def logcdf(
3527+
value: Union[float, np.ndarray, TensorVariable],
3528+
mu: Union[float, np.ndarray, TensorVariable],
3529+
beta: Union[float, np.ndarray, TensorVariable],
3530+
) -> TensorVariable:
35943531
"""
35953532
Compute the log of the cumulative distribution function for Gumbel distribution
35963533
at the specified value.
@@ -3605,13 +3542,7 @@ def logcdf(self, value):
36053542
-------
36063543
TensorVariable
36073544
"""
3608-
beta = self.beta
3609-
mu = self.mu
3610-
3611-
return bound(
3612-
-at.exp(-(value - mu) / beta),
3613-
0 < beta,
3614-
)
3545+
return bound(-at.exp(-(value - mu) / beta), 0 < beta,)
36153546

36163547

36173548
class Rice(PositiveContinuous):
@@ -3870,8 +3801,7 @@ def logp(self, value):
38703801
s = self.s
38713802

38723803
return bound(
3873-
-(value - mu) / s - at.log(s) - 2 * at.log1p(at.exp(-(value - mu) / s)),
3874-
s > 0,
3804+
-(value - mu) / s - at.log(s) - 2 * at.log1p(at.exp(-(value - mu) / s)), s > 0,
38753805
)
38763806

38773807
def logcdf(self, value):
@@ -3891,10 +3821,7 @@ def logcdf(self, value):
38913821
"""
38923822
mu = self.mu
38933823
s = self.s
3894-
return bound(
3895-
-log1pexp(-(value - mu) / s),
3896-
0 < s,
3897-
)
3824+
return bound(-log1pexp(-(value - mu) / s), 0 < s,)
38983825

38993826

39003827
class LogitNormal(UnitContinuous):
@@ -4253,7 +4180,5 @@ def logcdf(self, value):
42534180
sigma = self.sigma
42544181

42554182
scaled = (value - mu) / sigma
4256-
return bound(
4257-
at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))),
4258-
0 < sigma,
4259-
)
4183+
return bound(at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))), 0 < sigma,)
4184+

0 commit comments

Comments
 (0)