Skip to content

Commit b859845

Browse files
committed
pull out function get_constant_coords fn and add test
1 parent eeb4e12 commit b859845

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

pymc/sampling/forward.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,28 @@
7676
_log = logging.getLogger(__name__)
7777

7878

79+
def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> set:
80+
"""Get the set of coords that have remained constant between the trace and model"""
81+
constant_coords = set()
82+
for dim, coord in trace_coords.items():
83+
current_coord = model.coords.get(dim, None)
84+
current_length = model.dim_lengths.get(dim, None)
85+
if hasattr(current_length, "get_value"):
86+
current_length = current_length.get_value()
87+
elif hasattr(current_length, "data"):
88+
current_length = current_length.data
89+
if (
90+
current_coord is not None
91+
and len(coord) == len(current_coord)
92+
and np.all(coord == current_coord)
93+
) or (
94+
# Coord was defined without values (only length)
95+
current_coord is None and len(coord) == current_length
96+
):
97+
constant_coords.add(dim)
98+
return constant_coords
99+
100+
79101
def get_vars_in_point_list(trace, model):
80102
"""Get the list of Variable instances in the model that have values stored in the trace."""
81103
if not isinstance(trace, MultiTrace):
@@ -792,22 +814,7 @@ def sample_posterior_predictive(
792814
stacklevel=2,
793815
)
794816

795-
constant_coords = set()
796-
for dim, coord in trace_coords.items():
797-
current_coord = model.coords.get(dim, None)
798-
current_length = model.dim_lengths.get(dim, None)
799-
if hasattr(current_length, 'eval'):
800-
current_length = current_length.eval()
801-
if (
802-
current_coord is not None
803-
and len(coord) == len(current_coord)
804-
and np.all(coord == current_coord)
805-
) or (
806-
# Coord was defined without values (only length)
807-
current_coord is None
808-
and len(coord) == current_length
809-
):
810-
constant_coords.add(dim)
817+
constant_coords = get_constant_coords(trace_coords, model)
811818

812819
if var_names is not None:
813820
vars_ = [model[x] for x in var_names]

tests/sampling/test_forward.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from pymc.pytensorf import compile_pymc
3636
from pymc.sampling.forward import (
3737
compile_forward_sampling_function,
38+
get_constant_coords,
3839
get_vars_in_point_list,
3940
observed_dependent_deterministics,
4041
)
@@ -1669,6 +1670,17 @@ def test_Triangular(
16691670
assert prior["target"].shape == (prior_samples, *shape)
16701671

16711672

1673+
def test_get_constant_coords():
1674+
with pm.Model() as model:
1675+
model.add_coord("coord0", length=1)
1676+
1677+
trace_coords_same_len = {"coord0": np.array([0])}
1678+
assert "coord0" in get_constant_coords(trace_coords_same_len, model)
1679+
1680+
trace_coords_diff_len = {"coord0": np.array([0, 1])}
1681+
assert "coord0" not in get_constant_coords(trace_coords_diff_len, model)
1682+
1683+
16721684
def test_get_vars_in_point_list():
16731685
with pm.Model() as modelA:
16741686
pm.Normal("a", 0, 1)

0 commit comments

Comments
 (0)