Skip to content

Commit da457f2

Browse files
ricardoV94brandonwillard
authored andcommitted
Refactor Beta to use custom rng_fn clipped_beta_rv
1 parent 4a88c17 commit da457f2

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

pymc3/distributions/continuous.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from aesara.assert_op import Assert
2525
from aesara.tensor.random.basic import (
26-
beta,
26+
BetaRV,
2727
cauchy,
2828
exponential,
2929
gamma,
@@ -42,6 +42,7 @@
4242
SplineWrapper,
4343
betaln,
4444
bound,
45+
clipped_beta_rvs,
4546
gammaln,
4647
i0e,
4748
incomplete_beta,
@@ -1075,6 +1076,15 @@ def logcdf(self, value):
10751076
)
10761077

10771078

1079+
class BetaClippedRV(BetaRV):
1080+
@classmethod
1081+
def rng_fn(cls, rng, alpha, beta, size):
1082+
return clipped_beta_rvs(alpha, beta, size=size, random_state=rng)
1083+
1084+
1085+
beta = BetaClippedRV()
1086+
1087+
10781088
class Beta(UnitContinuous):
10791089
r"""
10801090
Beta log-likelihood.
@@ -1149,9 +1159,6 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar
11491159
alpha = aet.as_tensor_variable(floatX(alpha))
11501160
beta = aet.as_tensor_variable(floatX(beta))
11511161

1152-
# mean = alpha / (alpha + beta)
1153-
# variance = (alpha * beta) / ((alpha + beta) ** 2 * (alpha + beta + 1))
1154-
11551162
assert_negative_support(alpha, "alpha", "Beta")
11561163
assert_negative_support(beta, "beta", "Beta")
11571164

pymc3/distributions/dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def incomplete_beta(a, b, value):
600600
)
601601

602602

603-
def clipped_beta_rvs(a, b, size=None, dtype="float64"):
603+
def clipped_beta_rvs(a, b, size=None, random_state=None, dtype="float64"):
604604
"""Draw beta distributed random samples in the open :math:`(0, 1)` interval.
605605
606606
The samples are generated with ``scipy.stats.beta.rvs``, but any value that
@@ -635,6 +635,6 @@ def clipped_beta_rvs(a, b, size=None, dtype="float64"):
635635
is shifted to ``np.nextafter(1, 0, dtype=dtype)``.
636636
637637
"""
638-
out = scipy.stats.beta.rvs(a, b, size=size).astype(dtype)
638+
out = scipy.stats.beta.rvs(a, b, size=size, random_state=random_state).astype(dtype)
639639
lower, upper = _beta_clip_values[dtype]
640640
return np.maximum(np.minimum(out, upper), lower)

pymc3/tests/test_distributions_random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ class TestWald(BaseTestCases.BaseTestCase):
301301
params = {"mu": 1.0, "lam": 1.0, "alpha": 0.0}
302302

303303

304-
@pytest.mark.skip(reason="This test is covered by Aesara")
305304
class TestBeta(BaseTestCases.BaseTestCase):
306305
distribution = pm.Beta
307306
params = {"alpha": 1.0, "beta": 1.0}

0 commit comments

Comments
 (0)