Skip to content

Commit 52a1b3c

Browse files
lucianopazricardoV94
authored andcommitted
get_vars_in_point_list only considers variables that are both in the model and in the point list
1 parent eaa51f3 commit 52a1b3c

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

pymc/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,9 @@ def __getitem__(self, key):
16011601
except KeyError:
16021602
raise e
16031603

1604+
def __contains__(self, key):
1605+
return key in self.named_vars or self.name_for(key) in self.named_vars
1606+
16041607
def compile_fn(
16051608
self,
16061609
outs: Sequence[Variable],

pymc/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1592,7 +1592,7 @@ def get_vars_in_point_list(trace, model):
15921592
names_in_trace = list(trace[0])
15931593
else:
15941594
names_in_trace = trace.varnames
1595-
vars_in_trace = [model[v] for v in names_in_trace]
1595+
vars_in_trace = [model[v] for v in names_in_trace if v in model]
15961596
return vars_in_trace
15971597

15981598

pymc/tests/test_sampling.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
_get_seeds_per_chain,
4949
assign_step_methods,
5050
compile_forward_sampling_function,
51+
get_vars_in_point_list,
5152
)
5253
from pymc.step_methods import (
5354
NUTS,
@@ -2628,3 +2629,24 @@ def test_sample(self):
26282629
np.testing.assert_allclose(
26292630
x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
26302631
)
2632+
2633+
2634+
def test_get_vars_in_point_list():
2635+
with pm.Model() as modelA:
2636+
pm.Normal("a", 0, 1)
2637+
pm.Normal("b", 0, 1)
2638+
with pm.Model() as modelB:
2639+
a = pm.Normal("a", 0, 1)
2640+
pm.Normal("c", 0, 1)
2641+
2642+
point_list = [{"a": 0, "b": 0}]
2643+
vars_in_trace = get_vars_in_point_list(point_list, modelB)
2644+
assert set(vars_in_trace) == {a}
2645+
2646+
strace = pm.backends.NDArray(model=modelB, vars=modelA.free_RVs)
2647+
strace.setup(1, 1)
2648+
strace.values = point_list[0]
2649+
strace.draw_idx = 1
2650+
trace = MultiTrace([strace])
2651+
vars_in_trace = get_vars_in_point_list(trace, modelB)
2652+
assert set(vars_in_trace) == {a}

0 commit comments

Comments
 (0)