@@ -157,9 +157,7 @@ def __mul__(self, other):
157
157
)
158
158
159
159
def __neg__ (self ):
160
- return Domain (
161
- [- v for v in self .vals ], self .dtype , (- self .lower , - self .upper ), self .shape
162
- )
160
+ return Domain ([- v for v in self .vals ], self .dtype , (- self .lower , - self .upper ), self .shape )
163
161
164
162
165
163
def product (domains , n_samples = - 1 ):
@@ -177,9 +175,7 @@ def product(domains, n_samples=-1):
177
175
names , domains = zip (* domains .items ())
178
176
except ValueError : # domains.items() is empty
179
177
return []
180
- all_vals = [
181
- zip (names , val ) for val in itertools .product (* [d .vals for d in domains ])
182
- ]
178
+ all_vals = [zip (names , val ) for val in itertools .product (* [d .vals for d in domains ])]
183
179
if n_samples > 0 and len (all_vals ) > n_samples :
184
180
return (all_vals [j ] for j in nr .choice (len (all_vals ), n_samples , replace = False ))
185
181
return all_vals
@@ -428,9 +424,7 @@ def invlogit(x, eps=sys.float_info.epsilon):
428
424
429
425
def orderedlogistic_logpdf (value , eta , cutpoints ):
430
426
c = np .concatenate (([- np .inf ], cutpoints , [np .inf ]))
431
- ps = np .array (
432
- [invlogit (eta - cc ) - invlogit (eta - cc1 ) for cc , cc1 in zip (c [:- 1 ], c [1 :])]
433
- )
427
+ ps = np .array ([invlogit (eta - cc ) - invlogit (eta - cc1 ) for cc , cc1 in zip (c [:- 1 ], c [1 :])])
434
428
p = ps [value ]
435
429
return np .where (np .all (ps >= 0 ), np .log (p ), - np .inf )
436
430
@@ -445,9 +439,7 @@ def __init__(self, n):
445
439
class MultiSimplex :
446
440
def __init__ (self , n_dependent , n_independent ):
447
441
self .vals = []
448
- for simplex_value in itertools .product (
449
- simplex_values (n_dependent ), repeat = n_independent
450
- ):
442
+ for simplex_value in itertools .product (simplex_values (n_dependent ), repeat = n_independent ):
451
443
self .vals .append (np .vstack (simplex_value ))
452
444
self .shape = (n_independent , n_dependent )
453
445
self .dtype = Unit .dtype
@@ -468,16 +460,12 @@ def PdMatrix(n):
468
460
469
461
PdMatrix2 = Domain ([np .eye (2 ), [[0.5 , 0.05 ], [0.05 , 4.5 ]]], edges = (None , None ))
470
462
471
- PdMatrix3 = Domain (
472
- [np .eye (3 ), [[0.5 , 0.1 , 0 ], [0.1 , 1 , 0 ], [0 , 0 , 2.5 ]]], edges = (None , None )
473
- )
463
+ PdMatrix3 = Domain ([np .eye (3 ), [[0.5 , 0.1 , 0 ], [0.1 , 1 , 0 ], [0 , 0 , 2.5 ]]], edges = (None , None ))
474
464
475
465
476
466
PdMatrixChol1 = Domain ([np .eye (1 ), [[0.001 ]]], edges = (None , None ))
477
467
PdMatrixChol2 = Domain ([np .eye (2 ), [[0.1 , 0 ], [10 , 1 ]]], edges = (None , None ))
478
- PdMatrixChol3 = Domain (
479
- [np .eye (3 ), [[0.1 , 0 , 0 ], [10 , 100 , 0 ], [0 , 1 , 10 ]]], edges = (None , None )
480
- )
468
+ PdMatrixChol3 = Domain ([np .eye (3 ), [[0.1 , 0 , 0 ], [10 , 100 , 0 ], [0 , 1 , 10 ]]], edges = (None , None ))
481
469
482
470
483
471
def PdMatrixChol (n ):
@@ -538,19 +526,15 @@ def logp(args):
538
526
539
527
self .check_logp (model , value , domain , paramdomains , logp , decimal = decimal )
540
528
541
- def check_logp (
542
- self , model , value , domain , paramdomains , logp_reference , decimal = None
543
- ):
529
+ def check_logp (self , model , value , domain , paramdomains , logp_reference , decimal = None ):
544
530
domains = paramdomains .copy ()
545
531
domains ["value" ] = domain
546
532
logp = model .fastlogp
547
533
for pt in product (domains , n_samples = 100 ):
548
534
pt = Point (pt , model = model )
549
535
if decimal is None :
550
536
decimal = select_by_precision (float64 = 6 , float32 = 3 )
551
- assert_almost_equal (
552
- logp (pt ), logp_reference (pt ), decimal = decimal , err_msg = str (pt )
553
- )
537
+ assert_almost_equal (logp (pt ), logp_reference (pt ), decimal = decimal , err_msg = str (pt ))
554
538
555
539
def check_logcdf (
556
540
self ,
@@ -615,17 +599,13 @@ def test_triangular(self):
615
599
Triangular ,
616
600
Runif ,
617
601
{"lower" : - Rplusunif , "c" : Runif , "upper" : Rplusunif },
618
- lambda value , c , lower , upper : sp .triang .logpdf (
619
- value , c - lower , lower , upper - lower
620
- ),
602
+ lambda value , c , lower , upper : sp .triang .logpdf (value , c - lower , lower , upper - lower ),
621
603
)
622
604
self .check_logcdf (
623
605
Triangular ,
624
606
Runif ,
625
607
{"lower" : - Rplusunif , "c" : Runif , "upper" : Rplusunif },
626
- lambda value , c , lower , upper : sp .triang .logcdf (
627
- value , c - lower , lower , upper - lower
628
- ),
608
+ lambda value , c , lower , upper : sp .triang .logcdf (value , c - lower , lower , upper - lower ),
629
609
)
630
610
631
611
def test_bound_normal (self ):
@@ -774,9 +754,7 @@ def test_beta(self):
774
754
{"alpha" : Rplus , "beta" : Rplus },
775
755
lambda value , alpha , beta : sp .beta .logpdf (value , alpha , beta ),
776
756
)
777
- self .pymc3_matches_scipy (
778
- Beta , Unit , {"mu" : Unit , "sigma" : Rplus }, beta_mu_sigma
779
- )
757
+ self .pymc3_matches_scipy (Beta , Unit , {"mu" : Unit , "sigma" : Rplus }, beta_mu_sigma )
780
758
self .check_logcdf (
781
759
Beta ,
782
760
Unit ,
@@ -788,15 +766,10 @@ def test_kumaraswamy(self):
788
766
# Scipy does not have a built-in Kumaraswamy pdf
789
767
def scipy_log_pdf (value , a , b ):
790
768
return (
791
- np .log (a )
792
- + np .log (b )
793
- + (a - 1 ) * np .log (value )
794
- + (b - 1 ) * np .log (1 - value ** a )
769
+ np .log (a ) + np .log (b ) + (a - 1 ) * np .log (value ) + (b - 1 ) * np .log (1 - value ** a )
795
770
)
796
771
797
- self .pymc3_matches_scipy (
798
- Kumaraswamy , Unit , {"a" : Rplus , "b" : Rplus }, scipy_log_pdf
799
- )
772
+ self .pymc3_matches_scipy (Kumaraswamy , Unit , {"a" : Rplus , "b" : Rplus }, scipy_log_pdf )
800
773
801
774
def test_exponential (self ):
802
775
self .pymc3_matches_scipy (
@@ -821,9 +794,7 @@ def test_negative_binomial(self):
821
794
def test_fun (value , mu , alpha ):
822
795
return sp .nbinom .logpmf (value , alpha , 1 - mu / (mu + alpha ))
823
796
824
- self .pymc3_matches_scipy (
825
- NegativeBinomial , Nat , {"mu" : Rplus , "alpha" : Rplus }, test_fun
826
- )
797
+ self .pymc3_matches_scipy (NegativeBinomial , Nat , {"mu" : Rplus , "alpha" : Rplus }, test_fun )
827
798
828
799
def test_laplace (self ):
829
800
self .pymc3_matches_scipy (
@@ -844,9 +815,7 @@ def test_lognormal(self):
844
815
Lognormal ,
845
816
Rplus ,
846
817
{"mu" : R , "tau" : Rplusbig },
847
- lambda value , mu , tau : floatX (
848
- sp .lognorm .logpdf (value , tau ** - 0.5 , 0 , np .exp (mu ))
849
- ),
818
+ lambda value , mu , tau : floatX (sp .lognorm .logpdf (value , tau ** - 0.5 , 0 , np .exp (mu ))),
850
819
)
851
820
self .check_logcdf (
852
821
Lognormal ,
@@ -907,13 +876,9 @@ def test_gamma(self):
907
876
)
908
877
909
878
def test_fun (value , mu , sigma ):
910
- return sp .gamma .logpdf (
911
- value , mu ** 2 / sigma ** 2 , scale = 1.0 / (mu / sigma ** 2 )
912
- )
879
+ return sp .gamma .logpdf (value , mu ** 2 / sigma ** 2 , scale = 1.0 / (mu / sigma ** 2 ))
913
880
914
- self .pymc3_matches_scipy (
915
- Gamma , Rplus , {"mu" : Rplusbig , "sigma" : Rplusbig }, test_fun
916
- )
881
+ self .pymc3_matches_scipy (Gamma , Rplus , {"mu" : Rplusbig , "sigma" : Rplusbig }, test_fun )
917
882
918
883
self .check_logcdf (
919
884
Gamma ,
@@ -939,9 +904,7 @@ def test_fun(value, mu, sigma):
939
904
alpha , beta = InverseGamma ._get_alpha_beta (None , None , mu , sigma )
940
905
return sp .invgamma .logpdf (value , alpha , scale = beta )
941
906
942
- self .pymc3_matches_scipy (
943
- InverseGamma , Rplus , {"mu" : Rplus , "sigma" : Rplus }, test_fun
944
- )
907
+ self .pymc3_matches_scipy (InverseGamma , Rplus , {"mu" : Rplus , "sigma" : Rplus }, test_fun )
945
908
946
909
def test_pareto (self ):
947
910
self .pymc3_matches_scipy (
@@ -1001,9 +964,7 @@ def test_binomial(self):
1001
964
)
1002
965
1003
966
# Too lazy to propagate decimal parameter through the whole chain of deps
1004
- @pytest .mark .xfail (
1005
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1006
- )
967
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1007
968
def test_beta_binomial (self ):
1008
969
self .checkd (BetaBinomial , Nat , {"alpha" : Rplus , "beta" : Rplus , "n" : NatSmall })
1009
970
@@ -1012,9 +973,7 @@ def test_bernoulli(self):
1012
973
Bernoulli ,
1013
974
Bool ,
1014
975
{"logit_p" : R },
1015
- lambda value , logit_p : sp .bernoulli .logpmf (
1016
- value , scipy .special .expit (logit_p )
1017
- ),
976
+ lambda value , logit_p : sp .bernoulli .logpmf (value , scipy .special .expit (logit_p )),
1018
977
)
1019
978
self .pymc3_matches_scipy (
1020
979
Bernoulli , Bool , {"p" : Unit }, lambda value , p : sp .bernoulli .logpmf (value , p )
@@ -1047,21 +1006,15 @@ def test_bound_poisson(self):
1047
1006
assert np .isinf (x .logp ({"x" : 0 }))
1048
1007
1049
1008
def test_constantdist (self ):
1050
- self .pymc3_matches_scipy (
1051
- Constant , I , {"c" : I }, lambda value , c : np .log (c == value )
1052
- )
1009
+ self .pymc3_matches_scipy (Constant , I , {"c" : I }, lambda value , c : np .log (c == value ))
1053
1010
1054
1011
# Too lazy to propagate decimal parameter through the whole chain of deps
1055
- @pytest .mark .xfail (
1056
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1057
- )
1012
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1058
1013
def test_zeroinflatedpoisson (self ):
1059
1014
self .checkd (ZeroInflatedPoisson , Nat , {"theta" : Rplus , "psi" : Unit })
1060
1015
1061
1016
# Too lazy to propagate decimal parameter through the whole chain of deps
1062
- @pytest .mark .xfail (
1063
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1064
- )
1017
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1065
1018
def test_zeroinflatednegativebinomial (self ):
1066
1019
self .checkd (
1067
1020
ZeroInflatedNegativeBinomial ,
@@ -1070,9 +1023,7 @@ def test_zeroinflatednegativebinomial(self):
1070
1023
)
1071
1024
1072
1025
# Too lazy to propagate decimal parameter through the whole chain of deps
1073
- @pytest .mark .xfail (
1074
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1075
- )
1026
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1076
1027
def test_zeroinflatedbinomial (self ):
1077
1028
self .checkd (ZeroInflatedBinomial , Nat , {"n" : NatSmall , "p" : Unit , "psi" : Unit })
1078
1029
@@ -1298,9 +1249,7 @@ def test_mvt(self, n):
1298
1249
1299
1250
@pytest .mark .parametrize ("n" , [2 , 3 , 4 ])
1300
1251
def test_AR1 (self , n ):
1301
- self .pymc3_matches_scipy (
1302
- AR1 , Vector (R , n ), {"k" : Unit , "tau_e" : Rplus }, AR1_logpdf
1303
- )
1252
+ self .pymc3_matches_scipy (AR1 , Vector (R , n ), {"k" : Unit , "tau_e" : Rplus }, AR1_logpdf )
1304
1253
1305
1254
@pytest .mark .parametrize ("n" , [2 , 3 ])
1306
1255
def test_wishart (self , n ):
@@ -1325,9 +1274,7 @@ def test_lkj(self, x, eta, n, lp):
1325
1274
1326
1275
@pytest .mark .parametrize ("n" , [2 , 3 ])
1327
1276
def test_dirichlet (self , n ):
1328
- self .pymc3_matches_scipy (
1329
- Dirichlet , Simplex (n ), {"a" : Vector (Rplus , n )}, dirichlet_logpdf
1330
- )
1277
+ self .pymc3_matches_scipy (Dirichlet , Simplex (n ), {"a" : Vector (Rplus , n )}, dirichlet_logpdf )
1331
1278
1332
1279
def test_dirichlet_shape (self ):
1333
1280
a = tt .as_tensor_variable (np .r_ [1 , 2 ])
@@ -1529,9 +1476,7 @@ def logp(x):
1529
1476
1530
1477
def test_get_tau_sigma (self ):
1531
1478
sigma = np .array ([2 ])
1532
- assert_almost_equal (
1533
- continuous .get_tau_sigma (sigma = sigma ), [1.0 / sigma ** 2 , sigma ]
1534
- )
1479
+ assert_almost_equal (continuous .get_tau_sigma (sigma = sigma ), [1.0 / sigma ** 2 , sigma ])
1535
1480
1536
1481
@pytest .mark .parametrize (
1537
1482
"value,mu,sigma,nu,logp" ,
@@ -1582,9 +1527,7 @@ def test_ex_gaussian_cdf(self, value, mu, sigma, nu, logcdf):
1582
1527
err_msg = str ((value , mu , sigma , nu , logcdf )),
1583
1528
)
1584
1529
1585
- @pytest .mark .xfail (
1586
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1587
- )
1530
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1588
1531
def test_vonmises (self ):
1589
1532
self .pymc3_matches_scipy (
1590
1533
VonMises ,
@@ -1626,8 +1569,7 @@ def test_logitnormal(self):
1626
1569
Unit ,
1627
1570
{"mu" : R , "sigma" : Rplus },
1628
1571
lambda value , mu , sigma : (
1629
- sp .norm .logpdf (logit (value ), mu , sigma )
1630
- - (np .log (value ) + np .log1p (- value ))
1572
+ sp .norm .logpdf (logit (value ), mu , sigma ) - (np .log (value ) + np .log1p (- value ))
1631
1573
),
1632
1574
decimal = select_by_precision (float64 = 6 , float32 = 1 ),
1633
1575
)
@@ -1641,9 +1583,7 @@ def test_rice(self):
1641
1583
Rice ,
1642
1584
Rplus ,
1643
1585
{"nu" : Rplus , "sigma" : Rplusbig },
1644
- lambda value , nu , sigma : sp .rice .logpdf (
1645
- value , b = nu / sigma , loc = 0 , scale = sigma
1646
- ),
1586
+ lambda value , nu , sigma : sp .rice .logpdf (value , b = nu / sigma , loc = 0 , scale = sigma ),
1647
1587
)
1648
1588
self .pymc3_matches_scipy (
1649
1589
Rice ,
@@ -1652,9 +1592,7 @@ def test_rice(self):
1652
1592
lambda value , b , sigma : sp .rice .logpdf (value , b = b , loc = 0 , scale = sigma ),
1653
1593
)
1654
1594
1655
- @pytest .mark .xfail (
1656
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1657
- )
1595
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1658
1596
def test_moyal (self ):
1659
1597
self .pymc3_matches_scipy (
1660
1598
Moyal ,
@@ -1669,9 +1607,7 @@ def test_moyal(self):
1669
1607
lambda value , mu , sigma : floatX (sp .moyal .logcdf (value , mu , sigma )),
1670
1608
)
1671
1609
1672
- @pytest .mark .xfail (
1673
- condition = (theano .config .floatX == "float32" ), reason = "Fails on float32"
1674
- )
1610
+ @pytest .mark .xfail (condition = (theano .config .floatX == "float32" ), reason = "Fails on float32" )
1675
1611
def test_interpolated (self ):
1676
1612
for mu in R .vals :
1677
1613
for sigma in Rplus .vals :
@@ -1683,9 +1619,7 @@ class TestedInterpolated(Interpolated):
1683
1619
def __init__ (self , ** kwargs ):
1684
1620
x_points = np .linspace (xmin , xmax , 100000 )
1685
1621
pdf_points = sp .norm .pdf (x_points , loc = mu , scale = sigma )
1686
- super ().__init__ (
1687
- x_points = x_points , pdf_points = pdf_points , ** kwargs
1688
- )
1622
+ super ().__init__ (x_points = x_points , pdf_points = pdf_points , ** kwargs )
1689
1623
1690
1624
def ref_pdf (value ):
1691
1625
return np .where (
@@ -1896,9 +1830,10 @@ def func(x):
1896
1830
return - 2 * (x ** 2 ).sum ()
1897
1831
1898
1832
with pm .Model ():
1899
- pm .Normal ('x' )
1900
- y = pm .DensityDist ('y' , func )
1833
+ pm .Normal ("x" )
1834
+ y = pm .DensityDist ("y" , func )
1901
1835
pm .sample (draws = 5 , tune = 1 , mp_ctx = "spawn" )
1902
1836
1903
1837
import pickle
1838
+
1904
1839
pickle .loads (pickle .dumps (y ))
0 commit comments