Skip to content

Refactor model graph and allow suppressing dim lengths #7392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 3, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 159 additions & 58 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import warnings

from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from itertools import zip_longest
from os import path
from typing import Any

Expand All @@ -39,6 +41,37 @@
)


@dataclass
class PlateMeta:
names: tuple[str]
sizes: tuple[int]

def __hash__(self):
return hash((self.names, self.sizes))


def create_plate_label(
var_name: str,
plate_meta: PlateMeta,
include_size: bool = True,
) -> str:
def create_label(d: int, dname: str, dlen: int):
if not dname:
return f"{dlen}"

label = f"{dname}"

if include_size:
label = f"{label} ({dlen})"

return label

values = enumerate(
zip_longest(plate_meta.names, plate_meta.sizes, fillvalue=None),
)
return " x ".join(create_label(d, dname, dlen) for d, (dname, dlen) in values)


def fast_eval(var):
return function([], var, mode="FAST_COMPILE")()

Expand All @@ -53,6 +86,21 @@ class NodeType(str, Enum):
DATA = "Data"


@dataclass
class NodeMeta:
var: TensorVariable
node_type: NodeType

def __hash__(self):
return hash(self.var.name)


@dataclass
class Plate:
meta: PlateMeta
variables: list[NodeMeta]


GraphvizNodeKwargs = dict[str, Any]
NodeFormatter = Callable[[TensorVariable], GraphvizNodeKwargs]

Expand Down Expand Up @@ -265,31 +313,26 @@ def make_compute_graph(

def _make_node(
self,
var_name,
graph,
node: NodeMeta,
*,
node_formatters: NodeTypeFormatterMapping,
nx=False,
cluster=False,
add_node: Callable[[str, ...], None],
cluster: bool = False,
formatting: str = "plain",
):
"""Attaches the given variable to a graphviz or networkx Digraph"""
v = self.model[var_name]

node_type = get_node_type(var_name, self.model)
node_formatter = node_formatters[node_type]

kwargs = node_formatter(v)
node_formatter = node_formatters[node.node_type]
kwargs = node_formatter(node.var)

if cluster:
kwargs["cluster"] = cluster

if nx:
graph.add_node(var_name.replace(":", "&"), **kwargs)
else:
graph.node(var_name.replace(":", "&"), **kwargs)
add_node(node.var.name.replace(":", "&"), **kwargs)

def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, set[VarName]]:
def get_plates(
self,
var_names: Iterable[VarName] | None = None,
) -> list[Plate]:
"""Rough but surprisingly accurate plate detection.

Just groups by the shape of the underlying distribution. Will be wrong
Expand All @@ -302,32 +345,67 @@ def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, se
"""
plates = defaultdict(set)

# TODO: Evaluate all RV shapes and dim_length at once.
# This should help to find discrepancies, and
# avoids unnecessary function compiles for deetermining labels.
# TODO: Evaluate all RV shapes at once
# This should help find discrepencies, and
# avoids unnecessary function compiles for determining labels.
dim_lengths: dict[str, int] = {
name: fast_eval(value).item() for name, value in self.model.dim_lengths.items()
}

for var_name in self.vars_to_plot(var_names):
v = self.model[var_name]
shape: Sequence[int] = fast_eval(v.shape)
dim_labels = []
shape: tuple[int, ...] = tuple(fast_eval(v.shape))
if var_name in self.model.named_vars_to_dims:
# The RV is associated with `dims` information.
names = []
sizes = []
for d, dname in enumerate(self.model.named_vars_to_dims[var_name]):
if dname is None:
# Unnamed dimension in a `dims` tuple!
dlen = shape[d]
dname = f"{var_name}_dim{d}"
else:
dlen = fast_eval(self.model.dim_lengths[dname])
dim_labels.append(f"{dname} ({dlen})")
plate_label = " x ".join(dim_labels)
names.append(dname)
sizes.append(dim_lengths.get(dname, shape[d]))

plate_meta = PlateMeta(
names=tuple(names),
sizes=tuple(sizes),
)
Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this tbh. Are we creating one plate per variable? But a plate can contain multiple variables?

Also names is ambiguous, it is dim_names? We should name it like that to distinguish from var_names?
Also sizes -> dim_lengths

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah plates are hashable... so you mutate the same thing...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just working with what was there. The historical { str: set[VarName] } is created with loop which I changed to { PlateMeta : set[NodeMeta] }
But switched to list[Plate] ultimately.
Ideally, there could be more straight-foward path to list[Plate]

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic feels rather convoluted to be honest. Maybe we can take a step back and see what is actually needed.
Step1: Collect the dim names and dim lengths of every variable we want to plot. This seems simple enough, and we can do in a loop
Step2: Merge variables that have identical dim_names and dim_lengths into "plates". The hashable Plate thing may be a good trick to achieve that, or just a defaultdict with keys: tuple[dim_names, dim_lengths]

Would the code be more readable if we didn't try to do both things at once?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: Updated comment above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defaultdict with dims_names and dim_lengths is same as what is currently happening. But there is a wrapper class around it. Personally, I find the class helpful and more user friendly. But I could be wrong

For instance,

Plate(
    DimInfo(names=("obs", "covariate"), sizes=(10, 5)), 
    variables=[
        NodeInfo(X, node_type=DATA), 
        NodeInfo(X_transform, node_type=DETERMINISTIC), 
        NodeInfo(tvp, node_type=FREE_RV),
    ]
)

over

(("obs", "covariate"), (10, 5), (X, X_transform, tvp), (DATA, DETERMINSTIC, FREE_RV))

lines up a bit better in my mind that the first two are related objects and the last two are related objects as well

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, just thinking of how easy we make it for users to define their custom stuff. Either way seems manageable

else:
# The RV has no `dims` information.
dim_labels = [str(x) for x in shape]
plate_label = " x ".join(map(str, shape))
plates[plate_label].add(var_name)
plate_meta = PlateMeta(
names=(),
sizes=tuple(shape),
)

v = self.model[var_name]
node_type = get_node_type(var_name, self.model)
var = NodeMeta(var=v, node_type=node_type)
plates[plate_meta].add(var)

return [
Plate(meta=plate_meta, variables=list(variables))
for plate_meta, variables in plates.items()
]

def edges(
self,
var_names: Iterable[VarName] | None = None,
) -> list[tuple[VarName, VarName]]:
"""Get edges between the variables in the model.

Parameters
----------
var_names : iterable of str, optional
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph

Returns
-------
list of tuple
List of edges between the variables in the model.

return dict(plates)
"""
return [
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
for child, parents in self.make_compute_graph(var_names=var_names).items()
for parent in parents
]

def make_graph(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should make_graph and make_networkx now be functions that take plates and edges as inputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would just remove calling get_plates and edges methods. Don't have much of a preference

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make it more modular, in that if you find a way to create your own plates and edges, you can just pass it to the functions that then display it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sure. I think that makes sense then.

The dictionary of {PlateMeta : set[NodeMeta]} is a bit weird and hard to work with. i.e. set is not subscritable and looking up by PlateMeta key is a bit tricky.

I was thinking of having another object, Plate which would be:

@dataclass 
class Plate: 
    plate_meta: PlateMeta
    nodes: list[NodeMeta]

and that would be in the input to make_graph and make_networkx instead. Making the signature: (plates: list[Plate], edges: list[tuple[str, str]], ...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, does it make sense as a method still? Do you see model_to_graphviz taking this input as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lost track of the specific methods we're discussing. My low resolution guess was that once we have the plates / edges we can just pass them to a function that uses those to render graphviz or networkx graphs. Let me know if you were asking about something else or see a problem (or no point) with that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Let me push something up and you can give feedback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed.
If user has arbitrary list[Plate] and list[tuple[VarName, VarName]] then they can use make_graph or make_networkx in order to make the graphviz or networkx, respectively.
pm.model_to_graphviz and pm.model_to_networkx are still wrappers.
ModelGraph class can be used to create the plates and edges in the previous manner if desired with the get_plates and edges methods

self,
Expand All @@ -337,6 +415,7 @@ def make_graph(
figsize=None,
dpi=300,
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Make graphviz Digraph of PyMC model

Expand All @@ -357,26 +436,35 @@ def make_graph(
node_formatters = update_node_formatters(node_formatters)

graph = graphviz.Digraph(self.model.name)
for plate_label, all_var_names in self.get_plates(var_names).items():
if plate_label:
for plate in self.get_plates(var_names):
plate_meta = plate.meta
all_vars = plate.variables
if plate_meta.names or plate_meta.sizes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify? Could plate_meta be None for the scalar variables?

Copy link
Contributor Author

@williambdean williambdean Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic would still be needed somewhere. Likely in get_plates then.
How about having the __bool__ method for Plate class that does this logic.
Then would act like None and read like:

if plate_meta: # Truthy if sizes or names
    # plate_meta has sizes or names that are not empty tuples

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have that information when you defined the plate.meta no? Can't you do it immediately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I changed to have it happen in the get_plates methods. Scalars will have Plate(meta=None, variables=[...])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's enough to check for sizes? It is not possible for a plate to have names, but not sizes?

We should rename those to dim_names, and dim_lengths. And perhaps use None for dim_lengths for which we don't know the name?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So IIUC, scalars should belong to a "Plate" with dim_names = (), and dim_lengths = ()?

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And now I understand your approach and I think it was better like you did. The __bool__ sounds fine as well!

Sorry I got confused by the names of the things

# must be preceded by 'cluster' to get a box around it
plate_label = create_plate_label(
all_vars[0].var.name, plate_meta, include_size=include_shape_size
)
with graph.subgraph(name="cluster" + plate_label) as sub:
for var_name in all_var_names:
for var in all_vars:
self._make_node(
var_name, sub, formatting=formatting, node_formatters=node_formatters
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=sub.node,
)
# plate label goes bottom right
sub.attr(label=plate_label, labeljust="r", labelloc="b", style="rounded")
else:
for var_name in all_var_names:
for var in all_vars:
self._make_node(
var_name, graph, formatting=formatting, node_formatters=node_formatters
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=graph.node,
)

for child, parents in self.make_compute_graph(var_names=var_names).items():
# parents is a set of rv names that precede child rv nodes
for parent in parents:
graph.edge(parent.replace(":", "&"), child.replace(":", "&"))
for child, parent in self.edges(var_names=var_names):
graph.edge(parent, child)

if save is not None:
width, height = (None, None) if figsize is None else figsize
Expand All @@ -397,6 +485,7 @@ def make_networkx(
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Make networkx Digraph of PyMC model

Expand All @@ -417,20 +506,24 @@ def make_networkx(
node_formatters = update_node_formatters(node_formatters)

graphnetwork = networkx.DiGraph(name=self.model.name)
for plate_label, all_var_names in self.get_plates(var_names).items():
if plate_label:
for plate in self.get_plates(var_names):
plate_meta = plate.meta
all_vars = plate.variables
if plate_meta.names or plate_meta.sizes:
# # must be preceded by 'cluster' to get a box around it

plate_label = create_plate_label(
all_vars[0].var.name, plate_meta, include_size=include_shape_size
)
subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label)

for var_name in all_var_names:
for var in all_vars:
self._make_node(
var_name,
subgraphnetwork,
nx=True,
node=var,
node_formatters=node_formatters,
cluster="cluster" + plate_label,
formatting=formatting,
add_node=subgraphnetwork.add_node,
)
for sgn in subgraphnetwork.nodes:
networkx.set_node_attributes(
Expand All @@ -446,19 +539,17 @@ def make_networkx(
networkx.set_node_attributes(graphnetwork, node_data)
graphnetwork.graph["name"] = self.model.name
else:
for var_name in all_var_names:
for var in all_vars:
self._make_node(
var_name,
graphnetwork,
nx=True,
node=var,
formatting=formatting,
node_formatters=node_formatters,
add_node=graphnetwork.add_node,
)

for child, parents in self.make_compute_graph(var_names=var_names).items():
# parents is a set of rv names that precede child rv nodes
for parent in parents:
graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&"))
for child, parents in self.edges(var_names=var_names):
graphnetwork.add_edge(parents, child)

return graphnetwork


Expand All @@ -468,6 +559,7 @@ def model_to_networkx(
var_names: Iterable[VarName] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Produce a networkx Digraph from a PyMC model.

Expand All @@ -493,6 +585,8 @@ def model_to_networkx(
A dictionary mapping node types to functions that return a dictionary of node attributes.
Check out the networkx documentation for more information
how attributes are added to nodes: https://networkx.org/documentation/stable/reference/classes/generated/networkx.Graph.add_node.html
include_shape_size : bool
Include the shape size in the plate label. Default is True.

Examples
--------
Expand Down Expand Up @@ -541,7 +635,10 @@ def model_to_networkx(
)
model = pm.modelcontext(model)
return ModelGraph(model).make_networkx(
var_names=var_names, formatting=formatting, node_formatters=node_formatters
var_names=var_names,
formatting=formatting,
node_formatters=node_formatters,
include_shape_size=include_shape_size,
)


Expand All @@ -554,6 +651,7 @@ def model_to_graphviz(
figsize: tuple[int, int] | None = None,
dpi: int = 300,
node_formatters: NodeTypeFormatterMapping | None = None,
include_shape_size: bool = True,
):
"""Produce a graphviz Digraph from a PyMC model.

Expand Down Expand Up @@ -585,6 +683,8 @@ def model_to_graphviz(
A dictionary mapping node types to functions that return a dictionary of node attributes.
Check out graphviz documentation for more information on available
attributes. https://graphviz.org/docs/nodes/
include_shape_size : bool
Include the shape size in the plate label. Default is True.

Examples
--------
Expand Down Expand Up @@ -646,4 +746,5 @@ def model_to_graphviz(
figsize=figsize,
dpi=dpi,
node_formatters=node_formatters,
include_shape_size=include_shape_size,
)
Loading