Skip to content

Commit 5d2f697

Browse files
committed
Add test for tempered logp_dlogp_function
1 parent 19323fe commit 5d2f697

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

pymc3/tests/test_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,26 @@ def test_multiple_observed_rv():
380380
assert model['x'] == model['x']
381381
assert model['x'] in model.observed_RVs
382382
assert not model['x'] in model.vars
383+
384+
385+
def test_tempered_logp_dlogp():
386+
with pm.Model() as model:
387+
pm.Normal('x')
388+
pm.Normal('y', observed=1)
389+
390+
func = model.logp_dlogp_function()
391+
func.set_extra_values({})
392+
393+
func_temp = model.logp_dlogp_function(tempered=True)
394+
func_temp.set_extra_values({})
395+
396+
x = np.ones(func.size, dtype=func.dtype)
397+
assert func(x) == func_temp(x)
398+
399+
func_temp.set_weights(np.array([0.]))
400+
npt.assert_allclose(func(x)[0], 2 * func_temp(x)[0])
401+
npt.assert_allclose(func(x)[1], func_temp(x)[1])
402+
403+
func_temp.set_weights(np.array([0.5]))
404+
npt.assert_allclose(func(x)[0], 4 / 3 * func_temp(x)[0])
405+
npt.assert_allclose(func(x)[1], func_temp(x)[1])

0 commit comments

Comments
 (0)