Skip to content

Commit 56f78d5

Browse files
add str/repr formatting options and change defaults accordingly
+ new options "latex_with_params" and "plain_with_params" replace the current default behavior of including input parameters (going back to default behaviour of the 3.9.3 release) + __latex__ and _repr_latex default to "latex_with_params" + __str__ and _str_repr default to "plain" + new formatting kwarg for model_to_graphviz enables switching between "plain" (default) and "plain_with_params"
1 parent a6295a3 commit 56f78d5

File tree

3 files changed

+54
-25
lines changed

3 files changed

+54
-25
lines changed

pymc3/distributions/distribution.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,34 +164,51 @@ def _distr_name_for_repr(self):
164164
return self.__class__.__name__
165165

166166
def _str_repr(self, name=None, dist=None, formatting="plain"):
167-
"""Generate string representation for this distribution, optionally
167+
"""
168+
Generate string representation for this distribution, optionally
168169
including LaTeX markup (formatting='latex').
170+
171+
Parameters
172+
----------
173+
name : str
174+
name of the distribution
175+
dist : Distribution
176+
the distribution object
177+
formatting : str
178+
one of { "latex", "plain", "latex_with_params", "plain_with_params" }
169179
"""
170180
if dist is None:
171181
dist = self
172182
if name is None:
173183
name = "[unnamed]"
184+
supported_formattings = {"latex", "plain", "latex_with_params", "plain_with_params"}
185+
if not formatting in supported_formattings:
186+
raise ValueError(f"Unsupported formatting ''. Choose one of {supported_formattings}.")
174187

175188
param_names = self._distr_parameters_for_repr()
176189
param_values = [
177190
get_repr_for_variable(getattr(dist, x), formatting=formatting) for x in param_names
178191
]
179192

180-
if formatting == "latex":
193+
if "latex" in formatting:
181194
param_string = ",~".join(
182195
[fr"\mathit{{{name}}}={value}" for name, value in zip(param_names, param_values)]
183196
)
184-
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
185-
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
197+
if formatting == "latex_with_params":
198+
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(
199+
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
200+
)
201+
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}$".format(
202+
var_name=name, distr_name=dist._distr_name_for_repr()
186203
)
187204
else:
188-
# 'plain' is default option
205+
# one of the plain formattings
189206
param_string = ", ".join(
190207
[f"{name}={value}" for name, value in zip(param_names, param_values)]
191208
)
192-
return "{var_name} ~ {distr_name}({params})".format(
193-
var_name=name, distr_name=dist._distr_name_for_repr(), params=param_string
194-
)
209+
if formatting == "plain_with_params":
210+
return f"{name} ~ {dist._distr_name_for_repr()}({param_string})"
211+
return f"{name} ~ {dist._distr_name_for_repr()}"
195212

196213
def __str__(self, **kwargs):
197214
try:
@@ -201,7 +218,7 @@ def __str__(self, **kwargs):
201218

202219
def _repr_latex_(self, **kwargs):
203220
"""Magic method name for IPython to use for LaTeX formatting."""
204-
return self._str_repr(formatting="latex", **kwargs)
221+
return self._str_repr(formatting="latex_with_params", **kwargs)
205222

206223
def logp_nojac(self, *args, **kwargs):
207224
"""Return the logp, but do not include a jacobian term for transforms.

pymc3/model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,24 +1874,27 @@ def _walk_up_rv(rv, formatting="plain"):
18741874
all_rvs.extend(_walk_up_rv(parent, formatting=formatting))
18751875
else:
18761876
name = rv.name if rv.name else "Constant"
1877-
fmt = r"\text{{{name}}}" if formatting == "latex" else "{name}"
1877+
fmt = r"\text{{{name}}}" if "latex" in formatting else "{name}"
18781878
all_rvs.append(fmt.format(name=name))
18791879
return all_rvs
18801880

18811881

18821882
class DeterministicWrapper(tt.TensorVariable):
18831883
def _str_repr(self, formatting="plain"):
1884-
if formatting == "latex":
1885-
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
1886-
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
1887-
)
1884+
if "latex" in formatting:
1885+
if formatting == "latex_with_params":
1886+
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
1887+
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
1888+
)
1889+
return fr"$\text{{{self.name}}} \sim \text{{Deterministic}}$"
18881890
else:
1889-
return "{name} ~ Deterministic({args})".format(
1890-
name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting))
1891-
)
1891+
if formatting == "plain_with_params":
1892+
args = ", ".join(_walk_up_rv(self, formatting=formatting))
1893+
return f"{self.name} ~ Deterministic({args})"
1894+
return f"{self.name} ~ Deterministic"
18921895

18931896
def _repr_latex_(self):
1894-
return self._str_repr(formatting="latex")
1897+
return self._str_repr(formatting="latex_with_params")
18951898

18961899
__latex__ = _repr_latex_
18971900

pymc3/model_graph.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def update_input_map(key: str, val: Set[VarName]):
121121
pass
122122
return input_map
123123

124-
def _make_node(self, var_name, graph):
124+
def _make_node(self, var_name, graph, *, formatting: str = "plain"):
125125
"""Attaches the given variable to a graphviz Digraph"""
126126
v = self.model[var_name]
127127

@@ -146,7 +146,7 @@ def _make_node(self, var_name, graph):
146146
elif isinstance(v, SharedVariable):
147147
label = f"{var_name}\n~\nData"
148148
else:
149-
label = str(v).replace(" ~ ", "\n~\n")
149+
label = v._str_repr(formatting=formatting).replace(" ~ ", "\n~\n")
150150

151151
graph.node(var_name.replace(":", "&"), label, **attrs)
152152

@@ -181,7 +181,7 @@ def get_plates(self):
181181
plates[shape].add(var_name)
182182
return plates
183183

184-
def make_graph(self):
184+
def make_graph(self, formatting: str = "plain"):
185185
"""Make graphviz Digraph of PyMC3 model
186186
187187
Returns
@@ -205,20 +205,20 @@ def make_graph(self):
205205
# must be preceded by 'cluster' to get a box around it
206206
with graph.subgraph(name="cluster" + label) as sub:
207207
for var_name in var_names:
208-
self._make_node(var_name, sub)
208+
self._make_node(var_name, sub, formatting=formatting)
209209
# plate label goes bottom right
210210
sub.attr(label=label, labeljust="r", labelloc="b", style="rounded")
211211
else:
212212
for var_name in var_names:
213-
self._make_node(var_name, graph)
213+
self._make_node(var_name, graph, formatting=formatting)
214214

215215
for key, values in self.make_compute_graph().items():
216216
for value in values:
217217
graph.edge(value.replace(":", "&"), key.replace(":", "&"))
218218
return graph
219219

220220

221-
def model_to_graphviz(model=None):
221+
def model_to_graphviz(model=None, *, formatting: str = "plain"):
222222
"""Produce a graphviz Digraph from a PyMC3 model.
223223
224224
Requires graphviz, which may be installed most easily with
@@ -228,6 +228,15 @@ def model_to_graphviz(model=None):
228228
and then `pip install graphviz` to get the python bindings. See
229229
http://graphviz.readthedocs.io/en/stable/manual.html
230230
for more information.
231+
232+
Parameters
233+
----------
234+
model : pm.Model
235+
The model to plot. Not required when called from inside a modelcontext.
236+
formatting : str
237+
one of { "plain", "plain_with_params" }
231238
"""
239+
if not "plain" in formatting:
240+
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
232241
model = pm.modelcontext(model)
233-
return ModelGraph(model).make_graph()
242+
return ModelGraph(model).make_graph(formatting=formatting)

0 commit comments

Comments
 (0)