Skip to content

Commit 199451d

Browse files
Change tests for more refactored distributions.
More details on commit id 0773620b6f599423315035b97ef082ad32d98fd4
1 parent c6f2f31 commit 199451d

File tree

2 files changed

+51
-51
lines changed

2 files changed

+51
-51
lines changed

pymc3/distributions/discrete.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,24 +731,24 @@ def NegBinom(a, m, x):
731731

732732
@classmethod
733733
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
734-
n, p = cls.get_mu_alpha(mu, alpha, p, n)
734+
n, p = cls.get_n_p(mu, alpha, p, n)
735735
n = aet.as_tensor_variable(floatX(n))
736736
p = aet.as_tensor_variable(floatX(p))
737737
return super().dist([n, p], *args, **kwargs)
738738

739739
@classmethod
740-
def get_mu_alpha(cls, mu=None, alpha=None, p=None, n=None):
740+
def get_n_p(cls, mu=None, alpha=None, p=None, n=None):
741741
if n is None:
742742
if alpha is not None:
743-
n = aet.as_tensor_variable(floatX(alpha))
743+
n = alpha
744744
else:
745745
raise ValueError("Incompatible parametrization. Must specify either alpha or n.")
746746
elif alpha is not None:
747747
raise ValueError("Incompatible parametrization. Can't specify both alpha and n.")
748748

749749
if p is None:
750750
if mu is not None:
751-
mu = aet.as_tensor_variable(floatX(mu))
751+
mu = mu
752752
p = n / (mu + n)
753753
else:
754754
raise ValueError("Incompatible parametrization. Must specify either mu or p.")

pymc3/tests/test_distributions_random.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,7 @@ def get_inputs_from_apply_node_outputs(outputs):
550550
# I am assuming there will always only be 1 Apply parent node in this context
551551
return parents[0].inputs
552552

553-
def test_pymc_params_match_rv_ones(
554-
self, pymc_params, expected_aesara_params, pymc_dist, decimal=6
555-
):
553+
def _pymc_params_match_rv_ones(self, pymc_params, expected_aesara_params, pymc_dist, decimal=6):
556554
pymc_dist_output = pymc_dist.dist(**dict(pymc_params))
557555
aesera_dist_inputs = self.get_inputs_from_apply_node_outputs(pymc_dist_output)[3:]
558556
assert len(expected_aesara_params) == len(aesera_dist_inputs)
@@ -563,52 +561,88 @@ def test_pymc_params_match_rv_ones(
563561

564562
def test_normal(self):
565563
params = [("mu", 5.0), ("sigma", 10.0)]
566-
self.test_pymc_params_match_rv_ones(params, params, pm.Normal)
564+
self._pymc_params_match_rv_ones(params, params, pm.Normal)
567565

568566
def test_uniform(self):
569567
params = [("lower", 0.5), ("upper", 1.5)]
570-
self.test_pymc_params_match_rv_ones(params, params, pm.Uniform)
568+
self._pymc_params_match_rv_ones(params, params, pm.Uniform)
571569

572570
def test_half_normal(self):
573571
params, expected_aesara_params = [("sigma", 10.0)], [("mean", 0), ("sigma", 10.0)]
574-
self.test_pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
572+
self._pymc_params_match_rv_ones(params, expected_aesara_params, pm.HalfNormal)
575573

576574
def test_beta_alpha_beta(self):
577575
params = [("alpha", 2.0), ("beta", 5.0)]
578-
self.test_pymc_params_match_rv_ones(params, params, pm.Beta)
576+
self._pymc_params_match_rv_ones(params, params, pm.Beta)
579577

580578
def test_beta_mu_sigma(self):
581579
params = [("mu", 2.0), ("sigma", 5.0)]
582580
expected_alpha, expected_beta = pm.Beta.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
583581
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
584-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Beta)
582+
self._pymc_params_match_rv_ones(params, expected_params, pm.Beta)
585583

586584
@pytest.mark.skip(reason="Expected to fail due to bug")
587585
def test_exponential(self):
588586
params = [("lam", 10.0)]
589587
expected_params = [("lam", 1 / params[0][1])]
590-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
588+
self._pymc_params_match_rv_ones(params, expected_params, pm.Exponential)
591589

592590
def test_cauchy(self):
593591
params = [("alpha", 2.0), ("beta", 5.0)]
594-
self.test_pymc_params_match_rv_ones(params, params, pm.Cauchy)
592+
self._pymc_params_match_rv_ones(params, params, pm.Cauchy)
595593

596594
def test_half_cauchy(self):
597595
params = [("alpha", 2.0), ("beta", 5.0)]
598-
self.test_pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
596+
self._pymc_params_match_rv_ones(params, params, pm.HalfCauchy)
599597

600598
@pytest.mark.skip(reason="Expected to fail due to bug")
601599
def test_gamma_alpha_beta(self):
602600
params = [("alpha", 2.0), ("beta", 5.0)]
603601
expected_params = [("alpha", params[0][1]), ("beta", 1 / params[1][1])]
604-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
602+
self._pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
605603

606604
@pytest.mark.skip(reason="Expected to fail due to bug")
607605
def test_gamma_mu_sigma(self):
608606
params = [("mu", 2.0), ("sigma", 5.0)]
609607
expected_alpha, expected_beta = pm.Gamma.get_alpha_beta(mu=params[0][1], sigma=params[1][1])
610608
expected_params = [("alpha", expected_alpha), ("beta", 1 / expected_beta)]
611-
self.test_pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
609+
self._pymc_params_match_rv_ones(params, expected_params, pm.Gamma)
610+
611+
def test_inverse_gamma_alpha_beta(self):
612+
params = [("alpha", 2.0), ("beta", 5.0)]
613+
self._pymc_params_match_rv_ones(params, params, pm.InverseGamma)
614+
615+
def test_inverse_gamma_mu_sigma(self):
616+
params = [("mu", 2.0), ("sigma", 5.0)]
617+
expected_alpha, expected_beta = pm.InverseGamma._get_alpha_beta(
618+
mu=params[0][1], sigma=params[1][1], alpha=None, beta=None
619+
)
620+
expected_params = [("alpha", expected_alpha), ("beta", expected_beta)]
621+
self._pymc_params_match_rv_ones(params, expected_params, pm.InverseGamma)
622+
623+
def test_binomial(self):
624+
params = [("n", 100), ("p", 0.33)]
625+
self._pymc_params_match_rv_ones(params, params, pm.Binomial)
626+
627+
def test_negative_binomial(self):
628+
params = [("n", 100), ("p", 0.33)]
629+
self._pymc_params_match_rv_ones(params, params, pm.NegativeBinomial)
630+
631+
def test_negative_binomial_mu_sigma(self):
632+
params = [("mu", 5.0), ("alpha", 8.0)]
633+
expected_n, expected_p = pm.NegativeBinomial.get_n_p(
634+
mu=params[0][1], alpha=params[1][1], n=None, p=None
635+
)
636+
expected_params = [("n", expected_n), ("p", expected_p)]
637+
self._pymc_params_match_rv_ones(params, expected_params, pm.NegativeBinomial)
638+
639+
def test_bernoulli(self):
640+
params = [("p", 0.33)]
641+
self._pymc_params_match_rv_ones(params, params, pm.Bernoulli)
642+
643+
def test_poisson(self):
644+
params = [("mu", 4)]
645+
self._pymc_params_match_rv_ones(params, params, pm.Poisson)
612646

613647

614648
class TestScalarParameterSamples(SeededTest):
@@ -706,13 +740,6 @@ def ref_rand(size, nu, mu, lam):
706740

707741
pymc3_random(pm.StudentT, {"nu": Rplus, "mu": R, "lam": Rplus}, ref_rand=ref_rand)
708742

709-
@pytest.mark.skip(reason="This test is covered by Aesara")
710-
def test_inverse_gamma(self):
711-
def ref_rand(size, alpha, beta):
712-
return st.invgamma.rvs(a=alpha, scale=beta, size=size)
713-
714-
pymc3_random(pm.InverseGamma, {"alpha": Rplus, "beta": Rplus}, ref_rand=ref_rand)
715-
716743
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
717744
def test_pareto(self):
718745
def ref_rand(size, alpha, m):
@@ -759,10 +786,6 @@ def test_half_flat(self):
759786
with pytest.raises(ValueError):
760787
f.random(1)
761788

762-
@pytest.mark.skip(reason="This test is covered by Aesara")
763-
def test_binomial(self):
764-
pymc3_random_discrete(pm.Binomial, {"n": Nat, "p": Unit}, ref_rand=st.binom.rvs)
765-
766789
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
767790
@pytest.mark.xfail(
768791
sys.platform.startswith("win"),
@@ -776,29 +799,6 @@ def test_beta_binomial(self):
776799
def _beta_bin(self, n, alpha, beta, size=None):
777800
return st.binom.rvs(n, st.beta.rvs(a=alpha, b=beta, size=size))
778801

779-
@pytest.mark.skip(reason="This test is covered by Aesara")
780-
def test_bernoulli(self):
781-
pymc3_random_discrete(
782-
pm.Bernoulli, {"p": Unit}, ref_rand=lambda size, p=None: st.bernoulli.rvs(p, size=size)
783-
)
784-
785-
@pytest.mark.skip(reason="This test is covered by Aesara")
786-
def test_poisson(self):
787-
pymc3_random_discrete(pm.Poisson, {"mu": Rplusbig}, size=500, ref_rand=st.poisson.rvs)
788-
789-
@pytest.mark.skip(reason="This test is covered by Aesara")
790-
def test_negative_binomial(self):
791-
def ref_rand(size, alpha, mu):
792-
return st.nbinom.rvs(alpha, alpha / (mu + alpha), size=size)
793-
794-
pymc3_random_discrete(
795-
pm.NegativeBinomial,
796-
{"mu": Rplusbig, "alpha": Rplusbig},
797-
size=100,
798-
fails=50,
799-
ref_rand=ref_rand,
800-
)
801-
802802
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
803803
def test_geometric(self):
804804
pymc3_random_discrete(pm.Geometric, {"p": Unit}, size=500, fails=50, ref_rand=nr.geometric)

0 commit comments

Comments
 (0)