Skip to content

Commit eb16420

Browse files
committed
Fix ValueGradFunction tests
1 parent e4012b7 commit eb16420

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

pymc3/tests/test_model.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def test_no_extra(self):
265265
a.tag.test_value = np.zeros(3, dtype=a.dtype)
266266
a.dshape = (3,)
267267
a.dsize = 3
268-
f_grad = ValueGradFunction(a.sum(), [a], [], mode='FAST_COMPILE')
268+
f_grad = ValueGradFunction([a.sum()], [a], [], mode='FAST_COMPILE')
269269
assert f_grad.size == 3
270270

271271
def test_invalid_type(self):
@@ -274,7 +274,7 @@ def test_invalid_type(self):
274274
a.dshape = (3,)
275275
a.dsize = 3
276276
with pytest.raises(TypeError) as err:
277-
ValueGradFunction(a.sum(), [a], [], mode='FAST_COMPILE')
277+
ValueGradFunction([a.sum()], [a], [], mode='FAST_COMPILE')
278278
err.match('Invalid dtype')
279279

280280
def setUp(self):
@@ -303,7 +303,7 @@ def setUp(self):
303303
self.cost = extra1 * val1.sum() + val2.sum()
304304

305305
self.f_grad = ValueGradFunction(
306-
self.cost, [val1, val2], [extra1], mode='FAST_COMPILE')
306+
[self.cost], [val1, val2], [extra1], mode='FAST_COMPILE')
307307

308308
def test_extra_not_set(self):
309309
with pytest.raises(ValueError) as err:
@@ -396,10 +396,15 @@ def test_tempered_logp_dlogp():
396396
x = np.ones(func.size, dtype=func.dtype)
397397
assert func(x) == func_temp(x)
398398

399-
func_temp.set_weights(np.array([0.]))
399+
func_temp.set_weights(np.array([0.], dtype=func.dtype))
400+
func_temp_nograd.set_weights(np.array([0.], dtype=func.dtype))
400401
npt.assert_allclose(func(x)[0], 2 * func_temp(x)[0])
401402
npt.assert_allclose(func(x)[1], func_temp(x)[1])
402403

403-
func_temp.set_weights(np.array([0.5]))
404+
npt.assert_allclose(func_nograd(x), func(x)[0])
405+
npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0])
406+
407+
func_temp.set_weights(np.array([0.5], dtype=func.dtype))
408+
func_temp_nograd.set_weights(np.array([0.5], dtype=func.dtype))
404409
npt.assert_allclose(func(x)[0], 4 / 3 * func_temp(x)[0])
405410
npt.assert_allclose(func(x)[1], func_temp(x)[1])

0 commit comments

Comments
 (0)