Skip to content

Commit 5601822

Browse files
update if conditions, formatting defaults and tests
+ latex formatting should be detected by `if "latex" in formatting` to catch both format options + all latex reprs except for an entire model default to "latex_with_params" + tests now cover cases with and without params
1 parent 56f78d5 commit 5601822

File tree

7 files changed

+71
-46
lines changed

7 files changed

+71
-46
lines changed

pymc3/distributions/bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
246246
alpha = self.alpha
247247
m = self.m
248248

249-
if formatting == "latex":
249+
if "latex" in formatting:
250250
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
251251
else:
252252
return f"{name} ~ BART(alpha = {alpha}, m = {m})"

pymc3/distributions/bound.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,13 @@ def _distr_name_for_repr(self):
157157

158158
def _str_repr(self, **kwargs):
159159
distr_repr = self._wrapped._str_repr(**{**kwargs, "dist": self._wrapped})
160-
if "formatting" in kwargs and kwargs["formatting"] == "latex":
160+
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
161161
distr_repr = distr_repr[distr_repr.index(r" \sim") + 6 :]
162162
else:
163163
distr_repr = distr_repr[distr_repr.index(" ~") + 3 :]
164164
self_repr = super()._str_repr(**kwargs)
165165

166-
if "formatting" in kwargs and kwargs["formatting"] == "latex":
166+
if "formatting" in kwargs and "latex" in kwargs["formatting"]:
167167
return self_repr + " -- " + distr_repr
168168
else:
169169
return self_repr + "-" + distr_repr

pymc3/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ def __str__(self, **kwargs):
216216
except:
217217
return super().__str__()
218218

219-
def _repr_latex_(self, **kwargs):
219+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
220220
"""Magic method name for IPython to use for LaTeX formatting."""
221-
return self._str_repr(formatting="latex_with_params", **kwargs)
221+
return self._str_repr(formatting=formatting, **kwargs)
222222

223223
def logp_nojac(self, *args, **kwargs):
224224
"""Return the logp, but do not include a jacobian term for transforms.

pymc3/distributions/simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
126126
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat
127127
distance = getattr(self.distance, "__name__", self.distance.__class__.__name__)
128128

129-
if formatting == "latex":
129+
if "latex" in formatting:
130130
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
131131
else:
132132
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"

pymc3/model.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __rmatmul__(self, other):
6565

6666
def _str_repr(self, name=None, dist=None, formatting="plain"):
6767
if getattr(self, "distribution", None) is None:
68-
if formatting == "latex":
68+
if "latex" in formatting:
6969
return None
7070
else:
7171
return super().__str__()
@@ -76,8 +76,8 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
7676
dist = self.distribution
7777
return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)
7878

79-
def _repr_latex_(self, **kwargs):
80-
return self._str_repr(formatting="latex", **kwargs)
79+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
80+
return self._str_repr(formatting=formatting, **kwargs)
8181

8282
def __str__(self, **kwargs):
8383
try:
@@ -1375,8 +1375,8 @@ def check_test_point(self, test_point=None, round_vals=2):
13751375
def _str_repr(self, formatting="plain", **kwargs):
13761376
all_rv = itertools.chain(self.unobserved_RVs, self.observed_RVs)
13771377

1378-
if formatting == "latex":
1379-
rv_reprs = [rv.__latex__() for rv in all_rv]
1378+
if "latex" in formatting:
1379+
rv_reprs = [rv.__latex__(formatting=formatting) for rv in all_rv]
13801380
rv_reprs = [
13811381
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
13821382
for rv_repr in rv_reprs
@@ -1407,8 +1407,8 @@ def _str_repr(self, formatting="plain", **kwargs):
14071407
def __str__(self, **kwargs):
14081408
return self._str_repr(formatting="plain", **kwargs)
14091409

1410-
def _repr_latex_(self, **kwargs):
1411-
return self._str_repr(formatting="latex", **kwargs)
1410+
def _repr_latex_(self, *, formatting="latex", **kwargs):
1411+
return self._str_repr(formatting=formatting, **kwargs)
14121412

14131413
__latex__ = _repr_latex_
14141414

@@ -1893,8 +1893,8 @@ def _str_repr(self, formatting="plain"):
18931893
return f"{self.name} ~ Deterministic({args})"
18941894
return f"{self.name} ~ Deterministic"
18951895

1896-
def _repr_latex_(self):
1897-
return self._str_repr(formatting="latex_with_params")
1896+
def _repr_latex_(self, *, formatting="latex_with_params", **kwargs):
1897+
return self._str_repr(formatting=formatting)
18981898

18991899
__latex__ = _repr_latex_
19001900

pymc3/tests/test_distributions.py

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1800,58 +1800,83 @@ def setup_class(self):
18001800
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
18011801

18021802
self.distributions = [alpha, sigma, mu, b, Z, Y_obs, bound_var]
1803-
self.expected_latex = (
1804-
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1805-
r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$",
1806-
r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$",
1807-
r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1808-
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
1809-
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
1810-
r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1811-
r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$",
1812-
r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$",
1813-
)
1814-
self.expected_str = (
1815-
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
1816-
r"sigma ~ HalfNormal(sigma=1.0)",
1817-
r"mu ~ Deterministic(alpha, Constant, beta)",
1818-
r"beta ~ Normal(mu=0.0, sigma=10.0)",
1819-
r"Z ~ MvNormal(mu=array, chol_cov=array)",
1820-
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
1821-
r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)",
1822-
r"kron_normal ~ KroneckerNormal(mu=array)",
1823-
r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)",
1824-
)
1803+
self.expected = {
1804+
"latex": (
1805+
r"$\text{alpha} \sim \text{Normal}$",
1806+
r"$\text{sigma} \sim \text{HalfNormal}$",
1807+
r"$\text{mu} \sim \text{Deterministic}$",
1808+
r"$\text{beta} \sim \text{Normal}$",
1809+
r"$\text{Z} \sim \text{MvNormal}$",
1810+
r"$\text{Y_obs} \sim \text{Normal}$",
1811+
r"$\text{bound_var} \sim \text{Bound}$ -- \text{Normal}$",
1812+
r"$\text{kron_normal} \sim \text{KroneckerNormal}$",
1813+
r"$\text{mat_normal} \sim \text{MatrixNormal}$",
1814+
),
1815+
"plain": (
1816+
r"alpha ~ Normal",
1817+
r"sigma ~ HalfNormal",
1818+
r"mu ~ Deterministic",
1819+
r"beta ~ Normal",
1820+
r"Z ~ MvNormal",
1821+
r"Y_obs ~ Normal",
1822+
r"bound_var ~ Bound-Normal",
1823+
r"kron_normal ~ KroneckerNormal",
1824+
r"mat_normal ~ MatrixNormal",
1825+
),
1826+
"latex_with_params": (
1827+
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1828+
r"$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$",
1829+
r"$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$",
1830+
r"$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1831+
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
1832+
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
1833+
r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
1834+
r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$",
1835+
r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$",
1836+
),
1837+
"plain_with_params": (
1838+
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
1839+
r"sigma ~ HalfNormal(sigma=1.0)",
1840+
r"mu ~ Deterministic(alpha, Constant, beta)",
1841+
r"beta ~ Normal(mu=0.0, sigma=10.0)",
1842+
r"Z ~ MvNormal(mu=array, chol_cov=array)",
1843+
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
1844+
r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)",
1845+
r"kron_normal ~ KroneckerNormal(mu=array)",
1846+
r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)",
1847+
),
1848+
}
18251849

18261850
def test__repr_latex_(self):
1827-
for distribution, tex in zip(self.distributions, self.expected_latex):
1851+
for distribution, tex in zip(self.distributions, self.expected["latex_with_params"]):
18281852
assert distribution._repr_latex_() == tex
18291853

18301854
model_tex = self.model._repr_latex_()
18311855

1832-
for tex in self.expected_latex: # make sure each variable is in the model
1856+
# make sure each variable is in the model
1857+
for tex in self.expected["latex"]:
18331858
for segment in tex.strip("$").split(r"\sim"):
18341859
assert segment in model_tex
18351860

18361861
def test___latex__(self):
1837-
for distribution, tex in zip(self.distributions, self.expected_latex):
1862+
for distribution, tex in zip(self.distributions, self.expected["latex_with_params"]):
18381863
assert distribution._repr_latex_() == distribution.__latex__()
18391864
assert self.model._repr_latex_() == self.model.__latex__()
18401865

18411866
def test___str__(self):
1842-
for distribution, str_repr in zip(self.distributions, self.expected_str):
1867+
for distribution, str_repr in zip(self.distributions, self.expected["plain"]):
18431868
assert distribution.__str__() == str_repr
18441869

18451870
model_str = self.model.__str__()
1846-
for str_repr in self.expected_str:
1871+
for str_repr in self.expected["plain"]:
18471872
assert str_repr in model_str
18481873

18491874
def test_str(self):
1850-
for distribution, str_repr in zip(self.distributions, self.expected_str):
1875+
for distribution, str_repr in zip(self.distributions, self.expected["plain"]):
18511876
assert str(distribution) == str_repr
18521877

18531878
model_str = str(self.model)
1854-
for str_repr in self.expected_str:
1879+
for str_repr in self.expected["plain"]:
18551880
assert str_repr in model_str
18561881

18571882

pymc3/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def get_repr_for_variable(variable, formatting="plain"):
137137
for item in variable.get_parents()[0].inputs
138138
]
139139
# do not escape_latex these, since it is not idempotent
140-
if formatting == "latex":
140+
if "latex" in formatting:
141141
return "f({args})".format(
142142
args=",~".join([n for n in names if isinstance(n, str)])
143143
)
@@ -152,7 +152,7 @@ def get_repr_for_variable(variable, formatting="plain"):
152152
return value.item()
153153
return "array"
154154

155-
if formatting == "latex":
155+
if "latex" in formatting:
156156
return fr"\text{{{name}}}"
157157
else:
158158
return name

0 commit comments

Comments
 (0)