@@ -265,7 +265,7 @@ def test_no_extra(self):
265
265
a .tag .test_value = np .zeros (3 , dtype = a .dtype )
266
266
a .dshape = (3 ,)
267
267
a .dsize = 3
268
- f_grad = ValueGradFunction (a .sum (), [a ], [], mode = 'FAST_COMPILE' )
268
+ f_grad = ValueGradFunction ([ a .sum ()] , [a ], [], mode = 'FAST_COMPILE' )
269
269
assert f_grad .size == 3
270
270
271
271
def test_invalid_type (self ):
@@ -274,7 +274,7 @@ def test_invalid_type(self):
274
274
a .dshape = (3 ,)
275
275
a .dsize = 3
276
276
with pytest .raises (TypeError ) as err :
277
- ValueGradFunction (a .sum (), [a ], [], mode = 'FAST_COMPILE' )
277
+ ValueGradFunction ([ a .sum ()] , [a ], [], mode = 'FAST_COMPILE' )
278
278
err .match ('Invalid dtype' )
279
279
280
280
def setUp (self ):
@@ -303,7 +303,7 @@ def setUp(self):
303
303
self .cost = extra1 * val1 .sum () + val2 .sum ()
304
304
305
305
self .f_grad = ValueGradFunction (
306
- self .cost , [val1 , val2 ], [extra1 ], mode = 'FAST_COMPILE' )
306
+ [ self .cost ] , [val1 , val2 ], [extra1 ], mode = 'FAST_COMPILE' )
307
307
308
308
def test_extra_not_set (self ):
309
309
with pytest .raises (ValueError ) as err :
@@ -396,10 +396,15 @@ def test_tempered_logp_dlogp():
396
396
x = np .ones (func .size , dtype = func .dtype )
397
397
assert func (x ) == func_temp (x )
398
398
399
- func_temp .set_weights (np .array ([0. ]))
399
+ func_temp .set_weights (np .array ([0. ], dtype = func .dtype ))
400
+ func_temp_nograd .set_weights (np .array ([0. ], dtype = func .dtype ))
400
401
npt .assert_allclose (func (x )[0 ], 2 * func_temp (x )[0 ])
401
402
npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
402
403
403
- func_temp .set_weights (np .array ([0.5 ]))
404
+ npt .assert_allclose (func_nograd (x ), func (x )[0 ])
405
+ npt .assert_allclose (func_temp_nograd (x ), func_temp (x )[0 ])
406
+
407
+ func_temp .set_weights (np .array ([0.5 ], dtype = func .dtype ))
408
+ func_temp_nograd .set_weights (np .array ([0.5 ], dtype = func .dtype ))
404
409
npt .assert_allclose (func (x )[0 ], 4 / 3 * func_temp (x )[0 ])
405
410
npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
0 commit comments