Skip to content

Commit ab41e0d

Browse files
brandonwillardmatteo-pallini
authored andcommitted
Fix extra_vars in call to ValueGradFunction from Model
1 parent 4231ee2 commit ab41e0d

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

pymc3/model.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -377,10 +377,10 @@ def __init__(
377377
compute_grads=True,
378378
**kwargs,
379379
):
380-
if extra_vars is None:
381-
extra_vars = []
380+
if extra_vars_and_values is None:
381+
extra_vars_and_values = {}
382382

383-
names = [arg.name for arg in grad_vars + extra_vars]
383+
names = [arg.name for arg in grad_vars + list(extra_vars_and_values.keys())]
384384
if any(name is None for name in names):
385385
raise ValueError("Arguments must be named.")
386386
if len(set(names)) != len(names):
@@ -421,8 +421,8 @@ def __init__(
421421

422422
givens = []
423423
self._extra_vars_shared = {}
424-
for var in extra_vars:
425-
shared = aesara.shared(var.tag.test_value, var.name + "_shared__")
424+
for var, value in extra_vars_and_values.items():
425+
shared = aesara.shared(value, var.name + "_shared__")
426426
self._extra_vars_shared[var.name] = shared
427427
givens.append((var, shared))
428428

@@ -694,8 +694,13 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
694694
costs = [self.logpt]
695695

696696
input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
697-
extra_vars = [var for var in self.free_RVs if var in input_vars]
698-
return ValueGradFunction(costs, grad_vars, extra_vars, **kwargs)
697+
extra_vars = [getattr(var.tag, "value_var", var) for var in self.free_RVs]
698+
extra_vars_and_values = {
699+
var: self.test_point[var.name]
700+
for var in extra_vars
701+
if var in input_vars and var not in grad_vars
702+
}
703+
return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
699704

700705
@property
701706
def logpt(self):

pymc3/tests/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class TestValueGradFunction(unittest.TestCase):
223223
def test_no_extra(self):
224224
a = at.vector("a")
225225
a.tag.test_value = np.zeros(3, dtype=a.dtype)
226-
f_grad = ValueGradFunction([a.sum()], [a], [], mode="FAST_COMPILE")
226+
f_grad = ValueGradFunction([a.sum()], [a], {}, mode="FAST_COMPILE")
227227
assert f_grad._extra_vars == []
228228

229229
def test_invalid_type(self):

0 commit comments

Comments
 (0)