@@ -70,9 +70,9 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
         else:
             return super().__str__()

-        if name is None and hasattr(self, 'name'):
+        if name is None and hasattr(self, "name"):
             name = self.name
-        if dist is None and hasattr(self, 'distribution'):
+        if dist is None and hasattr(self, "distribution"):
             dist = self.distribution
         return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)

@@ -123,8 +123,7 @@ def incorporate_methods(source, destination, methods, wrapper=None, override=False):
     for method in methods:
         if hasattr(destination, method) and not override:
             raise AttributeError(
-                f"Cannot add method {method!r} "
-                + "to destination object as it already exists. "
+                f"Cannot add method {method!r} " + "to destination object as it already exists. "
                 "To prevent this error set 'override=True'."
             )
         if hasattr(source, method):
@@ -172,12 +171,8 @@ def get_named_nodes_and_relations(graph):
     else:
         ancestors = {}
         descendents = {}
-    descendents, ancestors = _get_named_nodes_and_relations(
-        graph, None, ancestors, descendents
-    )
-    leaf_dict = {
-        node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0
-    }
+    descendents, ancestors = _get_named_nodes_and_relations(graph, None, ancestors, descendents)
+    leaf_dict = {node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0}
     return leaf_dict, descendents, ancestors


@@ -529,9 +524,7 @@ def tree_contains(self, item):

     def __setitem__(self, key, value):
         raise NotImplementedError(
-            "Method is removed as we are not"
-            " able to determine "
-            "appropriate logic for it"
+            "Method is removed as we are not able to determine appropriate logic for it"
         )

     # Added this because mypy didn't like having __imul__ without __mul__
@@ -620,7 +613,7 @@ def __init__(
         dtype=None,
         casting="no",
         compute_grads=True,
-        **kwargs
+        **kwargs,
     ):
         from .distributions import TensorType

@@ -695,9 +688,7 @@ def __init__(

         inputs = [self._vars_joined]

-        self._theano_function = theano.function(
-            inputs, outputs, givens=givens, **kwargs
-        )
+        self._theano_function = theano.function(inputs, outputs, givens=givens, **kwargs)

     def set_weights(self, values):
         if values.shape != (self._n_costs - 1,):
@@ -713,10 +704,7 @@ def get_extra_values(self):
         if not self._extra_are_set:
             raise ValueError("Extra values are not set.")

-        return {
-            var.name: self._extra_vars_shared[var.name].get_value()
-            for var in self._extra_vars
-        }
+        return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}

     def __call__(self, array, grad_out=None, extra_vars=None):
         if extra_vars is not None:
@@ -727,8 +715,7 @@ def __call__(self, array, grad_out=None, extra_vars=None):

         if array.shape != (self.size,):
             raise ValueError(
-                "Invalid shape for array. Must be %s but is %s."
-                % ((self.size,), array.shape)
+                "Invalid shape for array. Must be {} but is {}.".format((self.size,), array.shape)
             )

         if grad_out is None:
@@ -758,13 +745,10 @@ def dict_to_array(self, point):
     def array_to_dict(self, array):
         """Convert an array to a dictionary containing the grad_vars."""
         if array.shape != (self.size,):
-            raise ValueError(
-                f"Array should have shape ({self.size},) but has {array.shape}"
-            )
+            raise ValueError(f"Array should have shape ({self.size},) but has {array.shape}")
         if array.dtype != self.dtype:
             raise ValueError(
-                "Array has invalid dtype. Should be %s but is %s"
-                % (self._dtype, self.dtype)
+                f"Array has invalid dtype. Should be {self._dtype} but is {self.dtype}"
             )
         point = {}
         for varmap in self._ordering.vmap:
@@ -988,17 +972,15 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
         for var in grad_vars:
             if var.dtype not in continuous_types:
                 raise ValueError(
-                    "Can only compute the gradient of " " continuous types: %s" % var
+                    "Can only compute the gradient of continuous types: %s" % var
                 )

         if tempered:
             with self:
-                free_RVs_logp = tt.sum([
-                    tt.sum(var.logpt) for var in self.free_RVs + self.potentials
-                ])
-                observed_RVs_logp = tt.sum([
-                    tt.sum(var.logpt) for var in self.observed_RVs
-                ])
+                free_RVs_logp = tt.sum(
+                    [tt.sum(var.logpt) for var in self.free_RVs + self.potentials]
+                )
+                observed_RVs_logp = tt.sum([tt.sum(var.logpt) for var in self.observed_RVs])

             costs = [free_RVs_logp, observed_RVs_logp]
         else:
@@ -1038,7 +1020,7 @@ def logp_nojact(self):
     @property
     def varlogpt(self):
         """Theano scalar of log-probability of the unobserved random variables
-            (excluding deterministic)."""
+        (excluding deterministic)."""
         with self:
             factors = [var.logpt for var in self.free_RVs]
             return tt.sum(factors)
@@ -1110,9 +1092,7 @@ def add_coords(self, coords):
             )
         if name in self.coords:
             if not coords[name].equals(self.coords[name]):
-                raise ValueError(
-                    "Duplicate and incompatible coordinate: %s." % name
-                )
+                raise ValueError("Duplicate and incompatible coordinate: %s." % name)
         else:
             self.coords[name] = coords[name]

@@ -1141,9 +1121,7 @@ def Var(self, name, dist, data=None, total_size=None, dims=None):
         if data is None:
             if getattr(dist, "transform", None) is None:
                 with self:
-                    var = FreeRV(
-                        name=name, distribution=dist, total_size=total_size, model=self
-                    )
+                    var = FreeRV(name=name, distribution=dist, total_size=total_size, model=self)
                 self.free_RVs.append(var)
             else:
                 with self:
@@ -1218,8 +1196,7 @@ def prefix(self):
         return "%s_" % self.name if self.name else ""

     def name_for(self, name):
-        """Checks if name has prefix and adds if needed
-        """
+        """Checks if name has prefix and adds if needed"""
         if self.prefix:
             if not name.startswith(self.prefix):
                 return f"{self.prefix}{name}"
@@ -1229,8 +1206,7 @@ def name_for(self, name):
         return name

     def name_of(self, name):
-        """Checks if name has prefix and deletes if needed
-        """
+        """Checks if name has prefix and deletes if needed"""
         if not self.prefix or not name:
             return name
         elif name.startswith(self.prefix):
@@ -1269,7 +1245,7 @@ def makefn(self, outs, mode=None, *args, **kwargs):
                 accept_inplace=True,
                 mode=mode,
                 *args,
-                **kwargs
+                **kwargs,
             )

     def fn(self, outs, mode=None, *args, **kwargs):
@@ -1391,10 +1367,7 @@ def check_test_point(self, test_point=None, round_vals=2):
             test_point = self.test_point

         return Series(
-            {
-                RV.name: np.round(RV.logp(self.test_point), round_vals)
-                for RV in self.basic_RVs
-            },
+            {RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
             name="Log-probability of test_point",
         )

@@ -1403,23 +1376,31 @@ def _str_repr(self, formatting="plain", **kwargs):

         if formatting == "latex":
             rv_reprs = [rv.__latex__() for rv in all_rv]
-            rv_reprs = [rv_repr.replace(r"\sim", r"&\sim &").strip("$")
-                        for rv_repr in rv_reprs if rv_repr is not None]
+            rv_reprs = [
+                rv_repr.replace(r"\sim", r"&\sim &").strip("$")
+                for rv_repr in rv_reprs
+                if rv_repr is not None
+            ]
             return r"""$$
             \begin{{array}}{{rcl}}
             {}
             \end{{array}}
             $$""".format(
-                "\\\\".join(rv_reprs))
+                "\\\\".join(rv_reprs)
+            )
         else:
             rv_reprs = [rv.__str__() for rv in all_rv]
-            rv_reprs = [rv_repr for rv_repr in rv_reprs if not 'TransformedDistribution()' in rv_repr]
+            rv_reprs = [
+                rv_repr for rv_repr in rv_reprs if not "TransformedDistribution()" in rv_repr
+            ]
             # align vars on their ~
-            names = [s[:s.index('~') - 1] for s in rv_reprs]
-            distrs = [s[s.index('~') + 2:] for s in rv_reprs]
+            names = [s[: s.index("~") - 1] for s in rv_reprs]
+            distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
             maxlen = str(max(len(x) for x in names))
-            rv_reprs = [('{name:>' + maxlen + '} ~ {distr}').format(name=n, distr=d)
-                        for n, d in zip(names, distrs)]
+            rv_reprs = [
+                ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
+                for n, d in zip(names, distrs)
+            ]
             return "\n".join(rv_reprs)

     def __str__(self, **kwargs):
@@ -1537,8 +1518,9 @@ def Point(*args, **kwargs):
     except Exception as e:
         raise TypeError(f"can't turn {args} and {kwargs} into a dict. {e}")
     return {
-        get_var_name(k): np.array(v) for k, v in d.items()
-        if get_var_name(k) in map(get_var_name, model.vars)
+        get_var_name(k): np.array(v)
+        for k, v in d.items()
+        if get_var_name(k) in map(get_var_name, model.vars)
     }


@@ -1593,11 +1575,7 @@ def _get_scaling(total_size, shape, ndim):
         denom = 1
         coef = floatX(total_size) / floatX(denom)
     elif isinstance(total_size, (list, tuple)):
-        if not all(
-            isinstance(i, int)
-            for i in total_size
-            if (i is not Ellipsis and i is not None)
-        ):
+        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
             raise TypeError(
                 "Unrecognized `total_size` type, expected "
                 "int or list of ints, got %r" % total_size
@@ -1625,16 +1603,13 @@ def _get_scaling(total_size, shape, ndim):
         else:
             shp_end = np.asarray([])
         shp_begin = shape[: len(begin)]
-        begin_coef = [
-            floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None
-        ]
+        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
         end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
         coefs = begin_coef + end_coef
         coef = tt.prod(coefs)
     else:
         raise TypeError(
-            "Unrecognized `total_size` type, expected "
-            "int or list of ints, got %r" % total_size
+            "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
         )
     return tt.as_tensor(floatX(coef))

@@ -1753,9 +1728,7 @@ def as_tensor(data, name, model, distribution):
             testval=testval,
             parent_dist=distribution,
         )
-        missing_values = FreeRV(
-            name=name + "_missing", distribution=fakedist, model=model
-        )
+        missing_values = FreeRV(name=name + "_missing", distribution=fakedist, model=model)
         constant = tt.as_tensor_variable(data.filled())

         dataTensor = tt.set_subtensor(constant[data.mask.nonzero()], missing_values)
@@ -1854,14 +1827,11 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
         """
         self.name = name
         self.data = {
-            name: as_tensor(data, name, model, distribution)
-            for name, data in data.items()
+            name: as_tensor(data, name, model, distribution) for name, data in data.items()
         }

         self.missing_values = [
-            datum.missing_values
-            for datum in self.data.values()
-            if datum.missing_values is not None
+            datum.missing_values for datum in self.data.values() if datum.missing_values is not None
         ]
         self.logp_elemwiset = distribution.logp(**self.data)
         # The logp might need scaling in minibatches.
@@ -1871,9 +1841,7 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
         self.total_size = total_size
         self.model = model
         self.distribution = distribution
-        self.scaling = _get_scaling(
-            total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim
-        )
+        self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)

     # Make hashable by id for draw_values
     def __hash__(self):
@@ -1888,7 +1856,7 @@ def __ne__(self, other):
         return not self == other


-def _walk_up_rv(rv, formatting='plain'):
+def _walk_up_rv(rv, formatting="plain"):
     """Walk up theano graph to get inputs for deterministic RV."""
     all_rvs = []
     parents = list(itertools.chain(*[j.inputs for j in rv.get_parents()]))
@@ -1903,21 +1871,23 @@ def _walk_up_rv(rv, formatting='plain'):


 class DeterministicWrapper(tt.TensorVariable):
-    def _str_repr(self, formatting='plain'):
-        if formatting == 'latex':
+    def _str_repr(self, formatting="plain"):
+        if formatting == "latex":
             return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
-                name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)))
+                name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
+            )
         else:
             return "{name} ~ Deterministic({args})".format(
-                name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting)))
+                name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting))
+            )

     def _repr_latex_(self):
-        return self._str_repr(formatting='latex')
+        return self._str_repr(formatting="latex")

     __latex__ = _repr_latex_

     def __str__(self):
-        return self._str_repr(formatting='plain')
+        return self._str_repr(formatting="plain")


 def Deterministic(name, var, model=None, dims=None):
@@ -1936,7 +1906,7 @@ def Deterministic(name, var, model=None, dims=None):
     var = var.copy(model.name_for(name))
     model.deterministics.append(var)
     model.add_random_variable(var, dims)
-    var.__class__ = DeterministicWrapper # adds str and latex functionality
+    var.__class__ = DeterministicWrapper  # adds str and latex functionality

     return var

@@ -2030,7 +2000,7 @@ def as_iterargs(data):

 def all_continuous(vars):
     """Check that vars not include discrete variables, excepting
-    ObservedRVs. """
+    ObservedRVs."""
     vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
     if any([var.dtype in pm.discrete_types for var in vars_]):
         return False