Skip to content

Commit 429276c

Browse files
authored
Merge pull request #2212 from a-rodin/latex_dists
Avoid storing bound methods as variables to prevent pickling problems
2 parents 5ee6682 + 4ae1ec3 commit 429276c

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

pymc3/distributions/transforms.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from . import distribution
66
from ..math import logit, invlogit
77
import numpy as np
8-
from functools import partial
98

109
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
1110
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
@@ -70,8 +69,13 @@ def __init__(self, dist, transform, *args, **kwargs):
7069
b = np.hstack(((np.atleast_1d(self.shape) == 1)[:-1], False))
7170
# force the last dim not broadcastable
7271
self.type = tt.TensorType(v.dtype, b)
73-
74-
self._repr_latex_ = partial(dist._repr_latex_, dist=dist)
72+
73+
def _repr_latex_(self, name=None, dist=None):
74+
if name is None:
75+
name = self.name
76+
if dist is None:
77+
dist = self.dist
78+
return dist._repr_latex_(self, name=name, dist=dist)
7579

7680
def logp(self, x):
7781
return (self.dist.logp(self.transform_used.backward(x)) +

pymc3/model.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import collections
22
import threading
33
import six
4-
from functools import partial
54

65
import numpy as np
76
import scipy.sparse as sps
@@ -823,7 +822,14 @@ def __init__(self, type=None, owner=None, index=None, name=None,
823822
methods=['random'],
824823
wrapper=InstanceMethod)
825824

826-
self._repr_latex_ = partial(distribution._repr_latex_, name=name, dist=distribution)
825+
def _repr_latex_(self, name=None, dist=None):
826+
if self.distribution is None:
827+
return None
828+
if name is None:
829+
name = self.name
830+
if dist is None:
831+
dist = self.distribution
832+
return self.distribution._repr_latex_(name=name, dist=dist)
827833

828834
@property
829835
def init_value(self):
@@ -916,8 +922,15 @@ def __init__(self, type=None, owner=None, index=None, name=None, data=None,
916922
inputs=[data], outputs=[self])
917923

918924
self.tag.test_value = theano.compile.view_op(data).tag.test_value
919-
920-
self._repr_latex_ = partial(distribution._repr_latex_, name=name, dist=distribution)
925+
926+
def _repr_latex_(self, name=None, dist=None):
927+
if self.distribution is None:
928+
return None
929+
if name is None:
930+
name = self.name
931+
if dist is None:
932+
dist = self.distribution
933+
return self.distribution._repr_latex_(name=name, dist=dist)
921934

922935
@property
923936
def init_value(self):
@@ -1016,6 +1029,7 @@ def __init__(self, type=None, owner=None, index=None, name=None,
10161029

10171030
if distribution is not None:
10181031
self.model = model
1032+
self.distribution = distribution
10191033

10201034
transformed_name = get_transformed_name(name, transform)
10211035

@@ -1032,7 +1046,14 @@ def __init__(self, type=None, owner=None, index=None, name=None,
10321046
methods=['random'],
10331047
wrapper=InstanceMethod)
10341048

1035-
self._repr_latex_ = partial(distribution._repr_latex_, name=name, dist=distribution)
1049+
def _repr_latex_(self, name=None, dist=None):
1050+
if self.distribution is None:
1051+
return None
1052+
if name is None:
1053+
name = self.name
1054+
if dist is None:
1055+
dist = self.distribution
1056+
return self.distribution._repr_latex_(name=name, dist=dist)
10361057

10371058
@property
10381059
def init_value(self):

0 commit comments

Comments
 (0)