-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Refactor model graph and allow suppressing dim lengths #7392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
be7a7f5
f96502f
962eab8
38428de
f492a03
d1b5390
6667557
559dc42
aec7ae5
6d8b2ee
382a573
a2e9e60
2411da0
633a8cc
950409d
b9bcf92
e30f6d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,10 @@ | |
import warnings | ||
|
||
from collections import defaultdict | ||
from collections.abc import Callable, Iterable, Sequence | ||
from collections.abc import Callable, Iterable | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from itertools import zip_longest | ||
from os import path | ||
from typing import Any | ||
|
||
|
@@ -39,6 +41,37 @@ | |
) | ||
|
||
|
||
@dataclass | ||
class PlateMeta: | ||
names: tuple[str] | ||
sizes: tuple[int] | ||
|
||
def __hash__(self): | ||
return hash((self.names, self.sizes)) | ||
|
||
|
||
def create_plate_label( | ||
var_name: str, | ||
plate_meta: PlateMeta, | ||
include_size: bool = True, | ||
) -> str: | ||
def create_label(d: int, dname: str, dlen: int): | ||
if not dname: | ||
return f"{dlen}" | ||
|
||
label = f"{dname}" | ||
|
||
if include_size: | ||
label = f"{label} ({dlen})" | ||
|
||
return label | ||
|
||
values = enumerate( | ||
zip_longest(plate_meta.names, plate_meta.sizes, fillvalue=None), | ||
) | ||
return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values) | ||
|
||
|
||
def fast_eval(var): | ||
return function([], var, mode="FAST_COMPILE")() | ||
|
||
|
@@ -53,6 +86,21 @@ class NodeType(str, Enum): | |
DATA = "Data" | ||
|
||
|
||
@dataclass | ||
class NodeMeta: | ||
var: TensorVariable | ||
node_type: NodeType | ||
|
||
def __hash__(self): | ||
return hash(self.var.name) | ||
|
||
|
||
@dataclass | ||
class Plate: | ||
meta: PlateMeta | ||
variables: list[NodeMeta] | ||
|
||
|
||
GraphvizNodeKwargs = dict[str, Any] | ||
NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs] | ||
|
||
|
@@ -265,31 +313,26 @@ def make_compute_graph( | |
|
||
def _make_node( | ||
self, | ||
var_name, | ||
graph, | ||
node: NodeMeta, | ||
*, | ||
node_formatters: NodeTypeFormatterMapping, | ||
nx=False, | ||
cluster=False, | ||
add_node: Callable[[str, ...], None], | ||
cluster: bool = False, | ||
formatting: str = "plain", | ||
): | ||
"""Attaches the given variable to a graphviz or networkx Digraph""" | ||
v = self.model[var_name] | ||
|
||
node_type = get_node_type(var_name, self.model) | ||
node_formatter = node_formatters[node_type] | ||
|
||
kwargs = node_formatter(v) | ||
node_formatter = node_formatters[node.node_type] | ||
kwargs = node_formatter(node.var) | ||
|
||
if cluster: | ||
kwargs["cluster"] = cluster | ||
|
||
if nx: | ||
graph.add_node(var_name.replace(":", "&"), **kwargs) | ||
else: | ||
graph.node(var_name.replace(":", "&"), **kwargs) | ||
add_node(node.var.name.replace(":", "&"), **kwargs) | ||
|
||
def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, set[VarName]]: | ||
def get_plates( | ||
self, | ||
var_names: Iterable[VarName] | None = None, | ||
) -> list[Plate]: | ||
"""Rough but surprisingly accurate plate detection. | ||
|
||
Just groups by the shape of the underlying distribution. Will be wrong | ||
|
@@ -302,32 +345,67 @@ def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, se | |
""" | ||
plates = defaultdict(set) | ||
|
||
# TODO: Evaluate all RV shapes and dim_length at once. | ||
# This should help to find discrepancies, and | ||
# avoids unnecessary function compiles for deetermining labels. | ||
# TODO: Evaluate all RV shapes at once | ||
# This should help find discrepencies, and | ||
# avoids unnecessary function compiles for determining labels. | ||
dim_lengths: dict[str, int] = { | ||
name: fast_eval(value).item() for name, value in self.model.dim_lengths.items() | ||
} | ||
|
||
for var_name in self.vars_to_plot(var_names): | ||
v = self.model[var_name] | ||
shape: Sequence[int] = fast_eval(v.shape) | ||
dim_labels = [] | ||
shape: tuple[int, ...] = tuple(fast_eval(v.shape)) | ||
if var_name in self.model.named_vars_to_dims: | ||
# The RV is associated with `dims` information. | ||
names = [] | ||
sizes = [] | ||
for d, dname in enumerate(self.model.named_vars_to_dims[var_name]): | ||
if dname is None: | ||
# Unnamed dimension in a `dims` tuple! | ||
dlen = shape[d] | ||
dname = f"{var_name}_dim{d}" | ||
williambdean marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
dlen = fast_eval(self.model.dim_lengths[dname]) | ||
dim_labels.append(f"{dname} ({dlen})") | ||
plate_label = " x ".join(dim_labels) | ||
names.append(dname) | ||
sizes.append(dim_lengths.get(dname, shape[d])) | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
plate_meta = PlateMeta( | ||
names=tuple(names), | ||
sizes=tuple(sizes), | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand this tbh. Are we creating one plate per variable? But a plate can contain multiple variables? Also names is ambiguous, it is dim_names? We should name it like that to distinguish from var_names? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah plates are hashable... so you mutate the same thing... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just working with what was there. The historical { str: set[VarName] } is created with loop which I changed to { PlateMeta : set[NodeMeta] } There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic feels rather convoluted to be honest. Maybe we can take a step back and see what is actually needed. Would the code be more readable if we didn't try to do both things at once? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Edit: Updated comment above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. defaultdict with dims_names and dim_lengths is same as what is currently happening. But there is a wrapper class around it. Personally, I find the class helpful and more user friendly. But I could be wrong For instance, Plate(
DimInfo(names=("obs", "covariate"), sizes=(10, 5)),
variables=[
NodeInfo(X, node_type=DATA),
NodeInfo(X_transform, node_type=DETERMINISTIC),
NodeInfo(tvp, node_type=FREE_RV),
]
) over (("obs", "covariate"), (10, 5), (X, X_transform, tvp), (DATA, DETERMINSTIC, FREE_RV)) lines up a bit better in my mind that the first two are related objects and the last two are related objects as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, just thinking of how easy we make it for users to define their custom stuff. Either way seems manageable |
||
else: | ||
# The RV has no `dims` information. | ||
dim_labels = [str(x) for x in shape] | ||
plate_label = " x ".join(map(str, shape)) | ||
plates[plate_label].add(var_name) | ||
plate_meta = PlateMeta( | ||
names=(), | ||
sizes=tuple(shape), | ||
) | ||
|
||
v = self.model[var_name] | ||
node_type = get_node_type(var_name, self.model) | ||
var = NodeMeta(var=v, node_type=node_type) | ||
plates[plate_meta].add(var) | ||
|
||
return [ | ||
Plate(meta=plate_meta, variables=list(variables)) | ||
for plate_meta, variables in plates.items() | ||
] | ||
|
||
def edges( | ||
self, | ||
var_names: Iterable[VarName] | None = None, | ||
) -> list[tuple[VarName, VarName]]: | ||
"""Get edges between the variables in the model. | ||
|
||
Parameters | ||
---------- | ||
var_names : iterable of str, optional | ||
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph | ||
|
||
Returns | ||
------- | ||
list of tuple | ||
List of edges between the variables in the model. | ||
|
||
return dict(plates) | ||
""" | ||
return [ | ||
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&"))) | ||
for child, parents in self.make_compute_graph(var_names=var_names).items() | ||
for parent in parents | ||
] | ||
|
||
def make_graph( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would just remove calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would make it more modular, in that if you find a way to create your own plates and edges, you can just pass it to the functions that then display it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, sure. I think that makes sense then. The dictionary of {PlateMeta : set[NodeMeta]} is a bit weird and hard to work with. i.e. set is not subscritable and looking up by PlateMeta key is a bit tricky. I was thinking of having another object, @dataclass
class Plate:
plate_meta: PlateMeta
nodes: list[NodeMeta] and that would be in the input to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, does it make sense as a method still? Do you see model_to_graphviz taking this input as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lost track of the specific methods we're discussing. My low resolution guess was that once we have the plates / edges we can just pass them to a function that uses those to render graphviz or networkx graphs. Let me know if you were asking about something else or see a problem (or no point) with that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good. Let me push something up and you can give feedback There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just pushed. |
||
self, | ||
|
@@ -337,6 +415,7 @@ def make_graph( | |
figsize=None, | ||
dpi=300, | ||
node_formatters: NodeTypeFormatterMapping | None = None, | ||
include_shape_size: bool = True, | ||
): | ||
"""Make graphviz Digraph of PyMC model | ||
|
||
|
@@ -357,26 +436,35 @@ def make_graph( | |
node_formatters = update_node_formatters(node_formatters) | ||
|
||
graph = graphviz.Digraph(self.model.name) | ||
for plate_label, all_var_names in self.get_plates(var_names).items(): | ||
if plate_label: | ||
for plate in self.get_plates(var_names): | ||
plate_meta = plate.meta | ||
all_vars = plate.variables | ||
if plate_meta.names or plate_meta.sizes: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we simplify? Could plate_meta be None for the scalar variables? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic would still be needed somewhere. Likely in if plate_meta: # Truthy if sizes or names
# plate_meta has sizes or names that are not empty tuples There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You have that information when you defined the plate.meta no? Can't you do it immediately? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. I changed to have it happen in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's enough to check for We should rename those to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So IIUC, scalars should belong to a "Plate" with dim_names = (), and dim_lengths = ()? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And now I understand your approach and I think it was better like you did. The Sorry I got confused by the names of the things |
||
# must be preceded by 'cluster' to get a box around it | ||
plate_label = create_plate_label( | ||
all_vars[0].var.name, plate_meta, include_size=include_shape_size | ||
) | ||
with graph.subgraph(name="cluster" + plate_label) as sub: | ||
for var_name in all_var_names: | ||
for var in all_vars: | ||
self._make_node( | ||
var_name, sub, formatting=formatting, node_formatters=node_formatters | ||
node=var, | ||
formatting=formatting, | ||
node_formatters=node_formatters, | ||
add_node=sub.node, | ||
) | ||
# plate label goes bottom right | ||
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded") | ||
else: | ||
for var_name in all_var_names: | ||
for var in all_vars: | ||
self._make_node( | ||
var_name, graph, formatting=formatting, node_formatters=node_formatters | ||
node=var, | ||
formatting=formatting, | ||
node_formatters=node_formatters, | ||
add_node=graph.node, | ||
) | ||
|
||
for child, parents in self.make_compute_graph(var_names=var_names).items(): | ||
# parents is a set of rv names that precede child rv nodes | ||
for parent in parents: | ||
graph.edge(parent.replace(":", "&"), child.replace(":", "&")) | ||
for child, parent in self.edges(var_names=var_names): | ||
graph.edge(parent, child) | ||
|
||
if save is not None: | ||
width, height = (None, None) if figsize is None else figsize | ||
|
@@ -397,6 +485,7 @@ def make_networkx( | |
var_names: Iterable[VarName] | None = None, | ||
formatting: str = "plain", | ||
node_formatters: NodeTypeFormatterMapping | None = None, | ||
include_shape_size: bool = True, | ||
): | ||
"""Make networkx Digraph of PyMC model | ||
|
||
|
@@ -417,20 +506,24 @@ def make_networkx( | |
node_formatters = update_node_formatters(node_formatters) | ||
|
||
graphnetwork = networkx.DiGraph(name=self.model.name) | ||
for plate_label, all_var_names in self.get_plates(var_names).items(): | ||
if plate_label: | ||
for plate in self.get_plates(var_names): | ||
plate_meta = plate.meta | ||
all_vars = plate.variables | ||
if plate_meta.names or plate_meta.sizes: | ||
# # must be preceded by 'cluster' to get a box around it | ||
|
||
plate_label = create_plate_label( | ||
all_vars[0].var.name, plate_meta, include_size=include_shape_size | ||
) | ||
subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) | ||
|
||
for var_name in all_var_names: | ||
for var in all_vars: | ||
self._make_node( | ||
var_name, | ||
subgraphnetwork, | ||
nx=True, | ||
node=var, | ||
node_formatters=node_formatters, | ||
cluster="cluster" + plate_label, | ||
formatting=formatting, | ||
add_node=subgraphnetwork.add_node, | ||
) | ||
for sgn in subgraphnetwork.nodes: | ||
networkx.set_node_attributes( | ||
|
@@ -446,19 +539,17 @@ def make_networkx( | |
networkx.set_node_attributes(graphnetwork, node_data) | ||
graphnetwork.graph["name"] = self.model.name | ||
else: | ||
for var_name in all_var_names: | ||
for var in all_vars: | ||
self._make_node( | ||
var_name, | ||
graphnetwork, | ||
nx=True, | ||
node=var, | ||
formatting=formatting, | ||
node_formatters=node_formatters, | ||
add_node=graphnetwork.add_node, | ||
) | ||
|
||
for child, parents in self.make_compute_graph(var_names=var_names).items(): | ||
# parents is a set of rv names that precede child rv nodes | ||
for parent in parents: | ||
graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&")) | ||
for child, parents in self.edges(var_names=var_names): | ||
graphnetwork.add_edge(parents, child) | ||
|
||
return graphnetwork | ||
|
||
|
||
|
@@ -468,6 +559,7 @@ def model_to_networkx( | |
var_names: Iterable[VarName] | None = None, | ||
formatting: str = "plain", | ||
node_formatters: NodeTypeFormatterMapping | None = None, | ||
include_shape_size: bool = True, | ||
): | ||
"""Produce a networkx Digraph from a PyMC model. | ||
|
||
|
@@ -493,6 +585,8 @@ def model_to_networkx( | |
A dictionary mapping node types to functions that return a dictionary of node attributes. | ||
Check out the networkx documentation for more information | ||
how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html | ||
include_shape_size : bool | ||
Include the shape size in the plate label. Default is True. | ||
|
||
Examples | ||
-------- | ||
|
@@ -541,7 +635,10 @@ def model_to_networkx( | |
) | ||
model = pm.modelcontext(model) | ||
return ModelGraph(model).make_networkx( | ||
var_names=var_names, formatting=formatting, node_formatters=node_formatters | ||
var_names=var_names, | ||
formatting=formatting, | ||
node_formatters=node_formatters, | ||
include_shape_size=include_shape_size, | ||
) | ||
|
||
|
||
|
@@ -554,6 +651,7 @@ def model_to_graphviz( | |
figsize: tuple[int, int] | None = None, | ||
dpi: int = 300, | ||
node_formatters: NodeTypeFormatterMapping | None = None, | ||
include_shape_size: bool = True, | ||
): | ||
"""Produce a graphviz Digraph from a PyMC model. | ||
|
||
|
@@ -585,6 +683,8 @@ def model_to_graphviz( | |
A dictionary mapping node types to functions that return a dictionary of node attributes. | ||
Check out graphviz documentation for more information on available | ||
attributes. https://graphviz.org/docs/nodes/ | ||
include_shape_size : bool | ||
Include the shape size in the plate label. Default is True. | ||
|
||
Examples | ||
-------- | ||
|
@@ -646,4 +746,5 @@ def model_to_graphviz( | |
figsize=figsize, | ||
dpi=dpi, | ||
node_formatters=node_formatters, | ||
include_shape_size=include_shape_size, | ||
) |
Uh oh!
There was an error while loading. Please reload this page.