Skip to content

Commit c85f7fc

Browse files
authored
Black formatted 15 files (#4113)
* Black formatted 15 files * Ran pyupgrade on all files
1 parent 56ccabb commit c85f7fc

15 files changed

+1875
-1778
lines changed

pymc3/tests/test_distributions.py

Lines changed: 36 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ def __mul__(self, other):
157157
)
158158

159159
def __neg__(self):
160-
return Domain(
161-
[-v for v in self.vals], self.dtype, (-self.lower, -self.upper), self.shape
162-
)
160+
return Domain([-v for v in self.vals], self.dtype, (-self.lower, -self.upper), self.shape)
163161

164162

165163
def product(domains, n_samples=-1):
@@ -177,9 +175,7 @@ def product(domains, n_samples=-1):
177175
names, domains = zip(*domains.items())
178176
except ValueError: # domains.items() is empty
179177
return []
180-
all_vals = [
181-
zip(names, val) for val in itertools.product(*[d.vals for d in domains])
182-
]
178+
all_vals = [zip(names, val) for val in itertools.product(*[d.vals for d in domains])]
183179
if n_samples > 0 and len(all_vals) > n_samples:
184180
return (all_vals[j] for j in nr.choice(len(all_vals), n_samples, replace=False))
185181
return all_vals
@@ -428,9 +424,7 @@ def invlogit(x, eps=sys.float_info.epsilon):
428424

429425
def orderedlogistic_logpdf(value, eta, cutpoints):
430426
c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
431-
ps = np.array(
432-
[invlogit(eta - cc) - invlogit(eta - cc1) for cc, cc1 in zip(c[:-1], c[1:])]
433-
)
427+
ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1) for cc, cc1 in zip(c[:-1], c[1:])])
434428
p = ps[value]
435429
return np.where(np.all(ps >= 0), np.log(p), -np.inf)
436430

@@ -445,9 +439,7 @@ def __init__(self, n):
445439
class MultiSimplex:
446440
def __init__(self, n_dependent, n_independent):
447441
self.vals = []
448-
for simplex_value in itertools.product(
449-
simplex_values(n_dependent), repeat=n_independent
450-
):
442+
for simplex_value in itertools.product(simplex_values(n_dependent), repeat=n_independent):
451443
self.vals.append(np.vstack(simplex_value))
452444
self.shape = (n_independent, n_dependent)
453445
self.dtype = Unit.dtype
@@ -468,16 +460,12 @@ def PdMatrix(n):
468460

469461
PdMatrix2 = Domain([np.eye(2), [[0.5, 0.05], [0.05, 4.5]]], edges=(None, None))
470462

471-
PdMatrix3 = Domain(
472-
[np.eye(3), [[0.5, 0.1, 0], [0.1, 1, 0], [0, 0, 2.5]]], edges=(None, None)
473-
)
463+
PdMatrix3 = Domain([np.eye(3), [[0.5, 0.1, 0], [0.1, 1, 0], [0, 0, 2.5]]], edges=(None, None))
474464

475465

476466
PdMatrixChol1 = Domain([np.eye(1), [[0.001]]], edges=(None, None))
477467
PdMatrixChol2 = Domain([np.eye(2), [[0.1, 0], [10, 1]]], edges=(None, None))
478-
PdMatrixChol3 = Domain(
479-
[np.eye(3), [[0.1, 0, 0], [10, 100, 0], [0, 1, 10]]], edges=(None, None)
480-
)
468+
PdMatrixChol3 = Domain([np.eye(3), [[0.1, 0, 0], [10, 100, 0], [0, 1, 10]]], edges=(None, None))
481469

482470

483471
def PdMatrixChol(n):
@@ -538,19 +526,15 @@ def logp(args):
538526

539527
self.check_logp(model, value, domain, paramdomains, logp, decimal=decimal)
540528

541-
def check_logp(
542-
self, model, value, domain, paramdomains, logp_reference, decimal=None
543-
):
529+
def check_logp(self, model, value, domain, paramdomains, logp_reference, decimal=None):
544530
domains = paramdomains.copy()
545531
domains["value"] = domain
546532
logp = model.fastlogp
547533
for pt in product(domains, n_samples=100):
548534
pt = Point(pt, model=model)
549535
if decimal is None:
550536
decimal = select_by_precision(float64=6, float32=3)
551-
assert_almost_equal(
552-
logp(pt), logp_reference(pt), decimal=decimal, err_msg=str(pt)
553-
)
537+
assert_almost_equal(logp(pt), logp_reference(pt), decimal=decimal, err_msg=str(pt))
554538

555539
def check_logcdf(
556540
self,
@@ -615,17 +599,13 @@ def test_triangular(self):
615599
Triangular,
616600
Runif,
617601
{"lower": -Rplusunif, "c": Runif, "upper": Rplusunif},
618-
lambda value, c, lower, upper: sp.triang.logpdf(
619-
value, c - lower, lower, upper - lower
620-
),
602+
lambda value, c, lower, upper: sp.triang.logpdf(value, c - lower, lower, upper - lower),
621603
)
622604
self.check_logcdf(
623605
Triangular,
624606
Runif,
625607
{"lower": -Rplusunif, "c": Runif, "upper": Rplusunif},
626-
lambda value, c, lower, upper: sp.triang.logcdf(
627-
value, c - lower, lower, upper - lower
628-
),
608+
lambda value, c, lower, upper: sp.triang.logcdf(value, c - lower, lower, upper - lower),
629609
)
630610

631611
def test_bound_normal(self):
@@ -774,9 +754,7 @@ def test_beta(self):
774754
{"alpha": Rplus, "beta": Rplus},
775755
lambda value, alpha, beta: sp.beta.logpdf(value, alpha, beta),
776756
)
777-
self.pymc3_matches_scipy(
778-
Beta, Unit, {"mu": Unit, "sigma": Rplus}, beta_mu_sigma
779-
)
757+
self.pymc3_matches_scipy(Beta, Unit, {"mu": Unit, "sigma": Rplus}, beta_mu_sigma)
780758
self.check_logcdf(
781759
Beta,
782760
Unit,
@@ -788,15 +766,10 @@ def test_kumaraswamy(self):
788766
# Scipy does not have a built-in Kumaraswamy pdf
789767
def scipy_log_pdf(value, a, b):
790768
return (
791-
np.log(a)
792-
+ np.log(b)
793-
+ (a - 1) * np.log(value)
794-
+ (b - 1) * np.log(1 - value ** a)
769+
np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value ** a)
795770
)
796771

797-
self.pymc3_matches_scipy(
798-
Kumaraswamy, Unit, {"a": Rplus, "b": Rplus}, scipy_log_pdf
799-
)
772+
self.pymc3_matches_scipy(Kumaraswamy, Unit, {"a": Rplus, "b": Rplus}, scipy_log_pdf)
800773

801774
def test_exponential(self):
802775
self.pymc3_matches_scipy(
@@ -821,9 +794,7 @@ def test_negative_binomial(self):
821794
def test_fun(value, mu, alpha):
822795
return sp.nbinom.logpmf(value, alpha, 1 - mu / (mu + alpha))
823796

824-
self.pymc3_matches_scipy(
825-
NegativeBinomial, Nat, {"mu": Rplus, "alpha": Rplus}, test_fun
826-
)
797+
self.pymc3_matches_scipy(NegativeBinomial, Nat, {"mu": Rplus, "alpha": Rplus}, test_fun)
827798

828799
def test_laplace(self):
829800
self.pymc3_matches_scipy(
@@ -844,9 +815,7 @@ def test_lognormal(self):
844815
Lognormal,
845816
Rplus,
846817
{"mu": R, "tau": Rplusbig},
847-
lambda value, mu, tau: floatX(
848-
sp.lognorm.logpdf(value, tau ** -0.5, 0, np.exp(mu))
849-
),
818+
lambda value, mu, tau: floatX(sp.lognorm.logpdf(value, tau ** -0.5, 0, np.exp(mu))),
850819
)
851820
self.check_logcdf(
852821
Lognormal,
@@ -907,13 +876,9 @@ def test_gamma(self):
907876
)
908877

909878
def test_fun(value, mu, sigma):
910-
return sp.gamma.logpdf(
911-
value, mu ** 2 / sigma ** 2, scale=1.0 / (mu / sigma ** 2)
912-
)
879+
return sp.gamma.logpdf(value, mu ** 2 / sigma ** 2, scale=1.0 / (mu / sigma ** 2))
913880

914-
self.pymc3_matches_scipy(
915-
Gamma, Rplus, {"mu": Rplusbig, "sigma": Rplusbig}, test_fun
916-
)
881+
self.pymc3_matches_scipy(Gamma, Rplus, {"mu": Rplusbig, "sigma": Rplusbig}, test_fun)
917882

918883
self.check_logcdf(
919884
Gamma,
@@ -939,9 +904,7 @@ def test_fun(value, mu, sigma):
939904
alpha, beta = InverseGamma._get_alpha_beta(None, None, mu, sigma)
940905
return sp.invgamma.logpdf(value, alpha, scale=beta)
941906

942-
self.pymc3_matches_scipy(
943-
InverseGamma, Rplus, {"mu": Rplus, "sigma": Rplus}, test_fun
944-
)
907+
self.pymc3_matches_scipy(InverseGamma, Rplus, {"mu": Rplus, "sigma": Rplus}, test_fun)
945908

946909
def test_pareto(self):
947910
self.pymc3_matches_scipy(
@@ -1001,9 +964,7 @@ def test_binomial(self):
1001964
)
1002965

1003966
# Too lazy to propagate decimal parameter through the whole chain of deps
1004-
@pytest.mark.xfail(
1005-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1006-
)
967+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1007968
def test_beta_binomial(self):
1008969
self.checkd(BetaBinomial, Nat, {"alpha": Rplus, "beta": Rplus, "n": NatSmall})
1009970

@@ -1012,9 +973,7 @@ def test_bernoulli(self):
1012973
Bernoulli,
1013974
Bool,
1014975
{"logit_p": R},
1015-
lambda value, logit_p: sp.bernoulli.logpmf(
1016-
value, scipy.special.expit(logit_p)
1017-
),
976+
lambda value, logit_p: sp.bernoulli.logpmf(value, scipy.special.expit(logit_p)),
1018977
)
1019978
self.pymc3_matches_scipy(
1020979
Bernoulli, Bool, {"p": Unit}, lambda value, p: sp.bernoulli.logpmf(value, p)
@@ -1047,21 +1006,15 @@ def test_bound_poisson(self):
10471006
assert np.isinf(x.logp({"x": 0}))
10481007

10491008
def test_constantdist(self):
1050-
self.pymc3_matches_scipy(
1051-
Constant, I, {"c": I}, lambda value, c: np.log(c == value)
1052-
)
1009+
self.pymc3_matches_scipy(Constant, I, {"c": I}, lambda value, c: np.log(c == value))
10531010

10541011
# Too lazy to propagate decimal parameter through the whole chain of deps
1055-
@pytest.mark.xfail(
1056-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1057-
)
1012+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
10581013
def test_zeroinflatedpoisson(self):
10591014
self.checkd(ZeroInflatedPoisson, Nat, {"theta": Rplus, "psi": Unit})
10601015

10611016
# Too lazy to propagate decimal parameter through the whole chain of deps
1062-
@pytest.mark.xfail(
1063-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1064-
)
1017+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
10651018
def test_zeroinflatednegativebinomial(self):
10661019
self.checkd(
10671020
ZeroInflatedNegativeBinomial,
@@ -1070,9 +1023,7 @@ def test_zeroinflatednegativebinomial(self):
10701023
)
10711024

10721025
# Too lazy to propagate decimal parameter through the whole chain of deps
1073-
@pytest.mark.xfail(
1074-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1075-
)
1026+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
10761027
def test_zeroinflatedbinomial(self):
10771028
self.checkd(ZeroInflatedBinomial, Nat, {"n": NatSmall, "p": Unit, "psi": Unit})
10781029

@@ -1298,9 +1249,7 @@ def test_mvt(self, n):
12981249

12991250
@pytest.mark.parametrize("n", [2, 3, 4])
13001251
def test_AR1(self, n):
1301-
self.pymc3_matches_scipy(
1302-
AR1, Vector(R, n), {"k": Unit, "tau_e": Rplus}, AR1_logpdf
1303-
)
1252+
self.pymc3_matches_scipy(AR1, Vector(R, n), {"k": Unit, "tau_e": Rplus}, AR1_logpdf)
13041253

13051254
@pytest.mark.parametrize("n", [2, 3])
13061255
def test_wishart(self, n):
@@ -1325,9 +1274,7 @@ def test_lkj(self, x, eta, n, lp):
13251274

13261275
@pytest.mark.parametrize("n", [2, 3])
13271276
def test_dirichlet(self, n):
1328-
self.pymc3_matches_scipy(
1329-
Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf
1330-
)
1277+
self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
13311278

13321279
def test_dirichlet_shape(self):
13331280
a = tt.as_tensor_variable(np.r_[1, 2])
@@ -1529,9 +1476,7 @@ def logp(x):
15291476

15301477
def test_get_tau_sigma(self):
15311478
sigma = np.array([2])
1532-
assert_almost_equal(
1533-
continuous.get_tau_sigma(sigma=sigma), [1.0 / sigma ** 2, sigma]
1534-
)
1479+
assert_almost_equal(continuous.get_tau_sigma(sigma=sigma), [1.0 / sigma ** 2, sigma])
15351480

15361481
@pytest.mark.parametrize(
15371482
"value,mu,sigma,nu,logp",
@@ -1582,9 +1527,7 @@ def test_ex_gaussian_cdf(self, value, mu, sigma, nu, logcdf):
15821527
err_msg=str((value, mu, sigma, nu, logcdf)),
15831528
)
15841529

1585-
@pytest.mark.xfail(
1586-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1587-
)
1530+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
15881531
def test_vonmises(self):
15891532
self.pymc3_matches_scipy(
15901533
VonMises,
@@ -1626,8 +1569,7 @@ def test_logitnormal(self):
16261569
Unit,
16271570
{"mu": R, "sigma": Rplus},
16281571
lambda value, mu, sigma: (
1629-
sp.norm.logpdf(logit(value), mu, sigma)
1630-
- (np.log(value) + np.log1p(-value))
1572+
sp.norm.logpdf(logit(value), mu, sigma) - (np.log(value) + np.log1p(-value))
16311573
),
16321574
decimal=select_by_precision(float64=6, float32=1),
16331575
)
@@ -1641,9 +1583,7 @@ def test_rice(self):
16411583
Rice,
16421584
Rplus,
16431585
{"nu": Rplus, "sigma": Rplusbig},
1644-
lambda value, nu, sigma: sp.rice.logpdf(
1645-
value, b=nu / sigma, loc=0, scale=sigma
1646-
),
1586+
lambda value, nu, sigma: sp.rice.logpdf(value, b=nu / sigma, loc=0, scale=sigma),
16471587
)
16481588
self.pymc3_matches_scipy(
16491589
Rice,
@@ -1652,9 +1592,7 @@ def test_rice(self):
16521592
lambda value, b, sigma: sp.rice.logpdf(value, b=b, loc=0, scale=sigma),
16531593
)
16541594

1655-
@pytest.mark.xfail(
1656-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1657-
)
1595+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
16581596
def test_moyal(self):
16591597
self.pymc3_matches_scipy(
16601598
Moyal,
@@ -1669,9 +1607,7 @@ def test_moyal(self):
16691607
lambda value, mu, sigma: floatX(sp.moyal.logcdf(value, mu, sigma)),
16701608
)
16711609

1672-
@pytest.mark.xfail(
1673-
condition=(theano.config.floatX == "float32"), reason="Fails on float32"
1674-
)
1610+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
16751611
def test_interpolated(self):
16761612
for mu in R.vals:
16771613
for sigma in Rplus.vals:
@@ -1683,9 +1619,7 @@ class TestedInterpolated(Interpolated):
16831619
def __init__(self, **kwargs):
16841620
x_points = np.linspace(xmin, xmax, 100000)
16851621
pdf_points = sp.norm.pdf(x_points, loc=mu, scale=sigma)
1686-
super().__init__(
1687-
x_points=x_points, pdf_points=pdf_points, **kwargs
1688-
)
1622+
super().__init__(x_points=x_points, pdf_points=pdf_points, **kwargs)
16891623

16901624
def ref_pdf(value):
16911625
return np.where(
@@ -1896,9 +1830,10 @@ def func(x):
18961830
return -2 * (x ** 2).sum()
18971831

18981832
with pm.Model():
1899-
pm.Normal('x')
1900-
y = pm.DensityDist('y', func)
1833+
pm.Normal("x")
1834+
y = pm.DensityDist("y", func)
19011835
pm.sample(draws=5, tune=1, mp_ctx="spawn")
19021836

19031837
import pickle
1838+
19041839
pickle.loads(pickle.dumps(y))

0 commit comments

Comments
 (0)