Skip to content

Commit e5c42b4

Browse files
fonnesbeckbrandonwillard
authored andcommitted
Converted Pareto distribution to v4
1 parent 1285ac7 commit e5c42b4

File tree

3 files changed

+19
-32
lines changed

3 files changed

+19
-32
lines changed

pymc3/distributions/continuous.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
halfnormal,
3535
invgamma,
3636
normal,
37+
pareto,
3738
uniform,
3839
)
3940
from aesara.tensor.random.op import RandomVariable
@@ -2029,23 +2030,19 @@ class Pareto(Continuous):
20292030
m: float
20302031
Scale parameter (m > 0).
20312032
"""
2033+
rv_op = pareto
20322034

2033-
def __init__(self, alpha, m, transform="lowerbound", *args, **kwargs):
2034-
self.alpha = alpha = at.as_tensor_variable(floatX(alpha))
2035-
self.m = m = at.as_tensor_variable(floatX(m))
2036-
2037-
self.mean = at.switch(at.gt(alpha, 1), alpha * m / (alpha - 1.0), np.inf)
2038-
self.median = m * 2.0 ** (1.0 / alpha)
2039-
self.variance = at.switch(
2040-
at.gt(alpha, 2), (alpha * m ** 2) / ((alpha - 2.0) * (alpha - 1.0) ** 2), np.inf
2041-
)
2035+
@classmethod
2036+
def dist(
2037+
cls, alpha: float = None, m: float = None, no_assert: bool = False, **kwargs
2038+
) -> RandomVariable:
2039+
alpha = at.as_tensor_variable(floatX(alpha))
2040+
m = at.as_tensor_variable(floatX(m))
20422041

20432042
assert_negative_support(alpha, "alpha", "Pareto")
20442043
assert_negative_support(m, "m", "Pareto")
20452044

2046-
if transform == "lowerbound":
2047-
transform = transforms.lowerbound(self.m)
2048-
super().__init__(transform=transform, *args, **kwargs)
2045+
return super().dist([alpha, m], **kwargs)
20492046

20502047
def _random(self, alpha, m, size=None):
20512048
u = np.random.uniform(size=size)
@@ -2071,7 +2068,11 @@ def random(self, point=None, size=None):
20712068
# alpha, m = draw_values([self.alpha, self.m], point=point, size=size)
20722069
# return generate_samples(self._random, alpha, m, dist_shape=self.shape, size=size)
20732070

2074-
def logp(self, value):
2071+
def logp(
2072+
value: Union[float, np.ndarray, TensorVariable],
2073+
alpha: Union[float, np.ndarray, TensorVariable],
2074+
m: Union[float, np.ndarray, TensorVariable],
2075+
):
20752076
"""
20762077
Calculate log-probability of Pareto distribution at specified value.
20772078
@@ -2085,8 +2086,6 @@ def logp(self, value):
20852086
-------
20862087
TensorVariable
20872088
"""
2088-
alpha = self.alpha
2089-
m = self.m
20902089
return bound(
20912090
at.log(alpha) + logpow(m, alpha) - logpow(value, alpha + 1),
20922091
value >= m,
@@ -2097,7 +2096,11 @@ def logp(self, value):
20972096
def _distr_parameters_for_repr(self):
20982097
return ["alpha", "m"]
20992098

2100-
def logcdf(self, value):
2099+
def logcdf(
2100+
value: Union[float, np.ndarray, TensorVariable],
2101+
alpha: Union[float, np.ndarray, TensorVariable],
2102+
m: Union[float, np.ndarray, TensorVariable],
2103+
):
21012104
"""
21022105
Compute the log of the cumulative distribution function for Pareto distribution
21032106
at the specified value.
@@ -2112,8 +2115,6 @@ def logcdf(self, value):
21122115
-------
21132116
TensorVariable
21142117
"""
2115-
m = self.m
2116-
alpha = self.alpha
21172118
arg = (m / value) ** alpha
21182119
return bound(
21192120
at.switch(

pymc3/tests/test_distributions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1402,7 +1402,6 @@ def test_fun(value, mu, sigma):
14021402
decimal=select_by_precision(float64=4, float32=3),
14031403
)
14041404

1405-
@pytest.mark.xfail(reason="Distribution not refactored yet")
14061405
def test_pareto(self):
14071406
self.check_logp(
14081407
Pareto,

pymc3/tests/test_distributions_random.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,6 @@ class TestStudentT(BaseTestCases.BaseTestCase):
342342
params = {"nu": 5.0, "mu": 0.0, "lam": 1.0}
343343

344344

345-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
346-
class TestPareto(BaseTestCases.BaseTestCase):
347-
distribution = pm.Pareto
348-
params = {"alpha": 0.5, "m": 1.0}
349-
350-
351345
@pytest.mark.skip(reason="This test is covered by Aesara")
352346
class TestCauchy(BaseTestCases.BaseTestCase):
353347
distribution = pm.Cauchy
@@ -681,13 +675,6 @@ def ref_rand(size, alpha, beta):
681675

682676
pymc3_random(pm.InverseGamma, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
683677

684-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
685-
def test_pareto(self):
686-
def ref_rand(size, alpha, m):
687-
return st.pareto.rvs(alpha, scale=m, size=size)
688-
689-
pymc3_random(pm.Pareto, {"alpha": Rplusbig, "m": Rplusbig}, ref_rand=ref_rand)
690-
691678
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
692679
def test_ex_gaussian(self):
693680
def ref_rand(size, mu, sigma, nu):

0 commit comments

Comments
 (0)