Skip to content

Commit e4012b7

Browse files
committed
Add option to pass existing logp_dlogp_function to sampler
1 parent 5d2f697 commit e4012b7

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

pymc3/step_methods/arraystep.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,20 +236,25 @@ def link_population(self, population, chain_index):
236236

237237
class GradientSharedStep(BlockedStep):
238238
def __init__(self, vars, model=None, blocked=True,
239-
dtype=None, **theano_kwargs):
239+
dtype=None, logp_dlogp_func=None, **theano_kwargs):
240240
model = modelcontext(model)
241241
self.vars = vars
242242
self.blocked = blocked
243243

244-
func = model.logp_dlogp_function(
245-
vars, dtype=dtype, **theano_kwargs)
244+
if logp_dlogp_func is None:
245+
func = model.logp_dlogp_function(
246+
vars, dtype=dtype, **theano_kwargs)
247+
else:
248+
func = logp_dlogp_func
246249

247250
# handle edge case discovered in #2948
248251
try:
249252
func.set_extra_values(model.test_point)
250253
q = func.dict_to_array(model.test_point)
251254
logp, dlogp = func(q)
252255
except ValueError:
256+
if logp_dlogp_func is not None:
257+
raise
253258
theano_kwargs.update(mode='FAST_COMPILE')
254259
func = model.logp_dlogp_function(
255260
vars, dtype=dtype, **theano_kwargs)

0 commit comments

Comments
 (0)