Skip to content

Commit 8873806

Browse files
fonnesbecktwiecki
authored andcommitted
Ran pre-commit scripts
1 parent c1323b1 commit 8873806

File tree

1 file changed

+87
-22
lines changed

1 file changed

+87
-22
lines changed

pymc3/distributions/continuous.py

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
nodes in PyMC.
1919
"""
2020

21+
from typing import Union
22+
2123
import aesara.tensor as at
2224
import numpy as np
23-
from typing import Union
2425

2526
from aesara.assert_op import Assert
26-
from aesara.tensor.var import TensorVariable
27-
from aesara.tensor.random.op import RandomVariable
2827
from aesara.tensor.random.basic import (
2928
BetaRV,
3029
cauchy,
@@ -37,6 +36,8 @@
3736
normal,
3837
uniform,
3938
)
39+
from aesara.tensor.random.op import RandomVariable
40+
from aesara.tensor.var import TensorVariable
4041
from scipy import stats
4142
from scipy.interpolate import InterpolatedUnivariateSpline
4243

@@ -262,7 +263,11 @@ def logcdf(value, lower, upper):
262263
return at.switch(
263264
at.lt(value, lower) | at.lt(upper, lower),
264265
-np.inf,
265-
at.switch(at.lt(value, upper), at.log(value - lower) - at.log(upper - lower), 0,),
266+
at.switch(
267+
at.lt(value, upper),
268+
at.log(value - lower) - at.log(upper - lower),
269+
0,
270+
),
266271
)
267272

268273

@@ -496,7 +501,10 @@ def logcdf(value, mu, sigma):
496501
-------
497502
TensorVariable
498503
"""
499-
return bound(normal_lcdf(mu, sigma, value), 0 < sigma,)
504+
return bound(
505+
normal_lcdf(mu, sigma, value),
506+
0 < sigma,
507+
)
500508

501509

502510
class TruncatedNormal(BoundedContinuous):
@@ -830,7 +838,11 @@ def logcdf(value, loc, sigma):
830838
TensorVariable
831839
"""
832840
z = zvalue(value, mu=loc, sigma=sigma)
833-
return bound(at.log1p(-at.erfc(z / at.sqrt(2.0))), loc <= value, 0 < sigma,)
841+
return bound(
842+
at.log1p(-at.erfc(z / at.sqrt(2.0))),
843+
loc <= value,
844+
0 < sigma,
845+
)
834846

835847
def _distr_parameters_for_repr(self):
836848
return ["sigma"]
@@ -1046,7 +1058,11 @@ def logcdf(self, value):
10461058
b = 2.0 / l + normal_lcdf(0, 1, -(q + 1.0) / r)
10471059

10481060
return bound(
1049-
at.switch(at.lt(value, np.inf), a + log1pexp(b - a), 0,),
1061+
at.switch(
1062+
at.lt(value, np.inf),
1063+
a + log1pexp(b - a),
1064+
0,
1065+
),
10501066
0 < value,
10511067
0 < mu,
10521068
0 < lam,
@@ -1208,7 +1224,11 @@ def logcdf(value, alpha, beta):
12081224
)
12091225

12101226
return bound(
1211-
at.switch(at.lt(value, 1), at.log(incomplete_beta(alpha, beta, value)), 0,),
1227+
at.switch(
1228+
at.lt(value, 1),
1229+
at.log(incomplete_beta(alpha, beta, value)),
1230+
0,
1231+
),
12121232
0 <= value,
12131233
0 < alpha,
12141234
0 < beta,
@@ -1403,7 +1423,11 @@ def logcdf(value, lam):
14031423
TensorVariable
14041424
"""
14051425
a = lam * value
1406-
return bound(log1mexp(a), 0 <= value, 0 <= lam,)
1426+
return bound(
1427+
log1mexp(a),
1428+
0 <= value,
1429+
0 <= lam,
1430+
)
14071431

14081432

14091433
class Laplace(Continuous):
@@ -1519,7 +1543,11 @@ def logcdf(self, value):
15191543
at.switch(
15201544
at.le(value, a),
15211545
at.log(0.5) + y,
1522-
at.switch(at.gt(y, 1), at.log1p(-0.5 * at.exp(-y)), at.log(1 - 0.5 * at.exp(-y)),),
1546+
at.switch(
1547+
at.gt(y, 1),
1548+
at.log1p(-0.5 * at.exp(-y)),
1549+
at.log(1 - 0.5 * at.exp(-y)),
1550+
),
15231551
),
15241552
0 < b,
15251553
)
@@ -1776,7 +1804,11 @@ def logcdf(self, value):
17761804
sigma = self.sigma
17771805
tau = self.tau
17781806

1779-
return bound(normal_lcdf(mu, sigma, at.log(value)), 0 < value, 0 < tau,)
1807+
return bound(
1808+
normal_lcdf(mu, sigma, at.log(value)),
1809+
0 < value,
1810+
0 < tau,
1811+
)
17801812

17811813

17821814
class StudentT(Continuous):
@@ -1940,7 +1972,12 @@ def logcdf(self, value):
19401972
sqrt_t2_nu = at.sqrt(t ** 2 + nu)
19411973
x = (t + sqrt_t2_nu) / (2.0 * sqrt_t2_nu)
19421974

1943-
return bound(at.log(incomplete_beta(nu / 2.0, nu / 2.0, x)), 0 < nu, 0 < sigma, 0 < lam,)
1975+
return bound(
1976+
at.log(incomplete_beta(nu / 2.0, nu / 2.0, x)),
1977+
0 < nu,
1978+
0 < sigma,
1979+
0 < lam,
1980+
)
19441981

19451982

19461983
class Pareto(Continuous):
@@ -2075,7 +2112,11 @@ def logcdf(self, value):
20752112
alpha = self.alpha
20762113
arg = (m / value) ** alpha
20772114
return bound(
2078-
at.switch(at.le(arg, 1e-5), at.log1p(-arg), at.log(1 - arg),),
2115+
at.switch(
2116+
at.le(arg, 1e-5),
2117+
at.log1p(-arg),
2118+
at.log(1 - arg),
2119+
),
20792120
m <= value,
20802121
0 < alpha,
20812122
0 < m,
@@ -2173,7 +2214,10 @@ def logcdf(value, alpha, beta):
21732214
-------
21742215
TensorVariable
21752216
"""
2176-
return bound(at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi), 0 < beta,)
2217+
return bound(
2218+
at.log(0.5 + at.arctan((value - alpha) / beta) / np.pi),
2219+
0 < beta,
2220+
)
21772221

21782222

21792223
class HalfCauchy(PositiveContinuous):
@@ -2257,7 +2301,11 @@ def logcdf(value, loc, beta):
22572301
-------
22582302
TensorVariable
22592303
"""
2260-
return bound(at.log(2 * at.arctan((value - loc) / beta) / np.pi), loc <= value, 0 < beta,)
2304+
return bound(
2305+
at.log(2 * at.arctan((value - loc) / beta) / np.pi),
2306+
loc <= value,
2307+
0 < beta,
2308+
)
22612309

22622310

22632311
class Gamma(PositiveContinuous):
@@ -2725,7 +2773,12 @@ def logcdf(self, value):
27252773
alpha = self.alpha
27262774
beta = self.beta
27272775
a = (value / beta) ** alpha
2728-
return bound(log1mexp(a), 0 <= value, 0 < alpha, 0 < beta,)
2776+
return bound(
2777+
log1mexp(a),
2778+
0 <= value,
2779+
0 < alpha,
2780+
0 < beta,
2781+
)
27292782

27302783

27312784
class HalfStudentT(PositiveContinuous):
@@ -3521,7 +3574,10 @@ def logp(
35213574
TensorVariable
35223575
"""
35233576
scaled = (value - mu) / beta
3524-
return bound(-scaled - at.exp(-scaled) - at.log(beta), 0 < beta,)
3577+
return bound(
3578+
-scaled - at.exp(-scaled) - at.log(beta),
3579+
0 < beta,
3580+
)
35253581

35263582
def logcdf(
35273583
value: Union[float, np.ndarray, TensorVariable],
@@ -3542,7 +3598,10 @@ def logcdf(
35423598
-------
35433599
TensorVariable
35443600
"""
3545-
return bound(-at.exp(-(value - mu) / beta), 0 < beta,)
3601+
return bound(
3602+
-at.exp(-(value - mu) / beta),
3603+
0 < beta,
3604+
)
35463605

35473606

35483607
class Rice(PositiveContinuous):
@@ -3801,7 +3860,8 @@ def logp(self, value):
38013860
s = self.s
38023861

38033862
return bound(
3804-
-(value - mu) / s - at.log(s) - 2 * at.log1p(at.exp(-(value - mu) / s)), s > 0,
3863+
-(value - mu) / s - at.log(s) - 2 * at.log1p(at.exp(-(value - mu) / s)),
3864+
s > 0,
38053865
)
38063866

38073867
def logcdf(self, value):
@@ -3821,7 +3881,10 @@ def logcdf(self, value):
38213881
"""
38223882
mu = self.mu
38233883
s = self.s
3824-
return bound(-log1pexp(-(value - mu) / s), 0 < s,)
3884+
return bound(
3885+
-log1pexp(-(value - mu) / s),
3886+
0 < s,
3887+
)
38253888

38263889

38273890
class LogitNormal(UnitContinuous):
@@ -4180,5 +4243,7 @@ def logcdf(self, value):
41804243
sigma = self.sigma
41814244

41824245
scaled = (value - mu) / sigma
4183-
return bound(at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))), 0 < sigma,)
4184-
4246+
return bound(
4247+
at.log(at.erfc(at.exp(-scaled / 2) * (2 ** -0.5))),
4248+
0 < sigma,
4249+
)

0 commit comments

Comments
 (0)