File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -380,3 +380,26 @@ def test_multiple_observed_rv():
380
380
assert model ['x' ] == model ['x' ]
381
381
assert model ['x' ] in model .observed_RVs
382
382
assert not model ['x' ] in model .vars
383
+
384
+
385
+ def test_tempered_logp_dlogp ():
386
+ with pm .Model () as model :
387
+ pm .Normal ('x' )
388
+ pm .Normal ('y' , observed = 1 )
389
+
390
+ func = model .logp_dlogp_function ()
391
+ func .set_extra_values ({})
392
+
393
+ func_temp = model .logp_dlogp_function (tempered = True )
394
+ func_temp .set_extra_values ({})
395
+
396
+ x = np .ones (func .size , dtype = func .dtype )
397
+ assert func (x ) == func_temp (x )
398
+
399
+ func_temp .set_weights (np .array ([0. ]))
400
+ npt .assert_allclose (func (x )[0 ], 2 * func_temp (x )[0 ])
401
+ npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
402
+
403
+ func_temp .set_weights (np .array ([0.5 ]))
404
+ npt .assert_allclose (func (x )[0 ], 4 / 3 * func_temp (x )[0 ])
405
+ npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
You can’t perform that action at this time.
0 commit comments