@@ -377,10 +377,10 @@ def __init__(
377
377
compute_grads = True ,
378
378
** kwargs ,
379
379
):
380
- if extra_vars is None :
381
- extra_vars = []
380
+ if extra_vars_and_values is None :
381
+ extra_vars_and_values = {}
382
382
383
- names = [arg .name for arg in grad_vars + extra_vars ]
383
+ names = [arg .name for arg in grad_vars + list ( extra_vars_and_values . keys ()) ]
384
384
if any (name is None for name in names ):
385
385
raise ValueError ("Arguments must be named." )
386
386
if len (set (names )) != len (names ):
@@ -421,8 +421,8 @@ def __init__(
421
421
422
422
givens = []
423
423
self ._extra_vars_shared = {}
424
- for var in extra_vars :
425
- shared = aesara .shared (var . tag . test_value , var .name + "_shared__" )
424
+ for var , value in extra_vars_and_values . items () :
425
+ shared = aesara .shared (value , var .name + "_shared__" )
426
426
self ._extra_vars_shared [var .name ] = shared
427
427
givens .append ((var , shared ))
428
428
@@ -694,8 +694,13 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
694
694
costs = [self .logpt ]
695
695
696
696
input_vars = {i for i in graph_inputs (costs ) if not isinstance (i , Constant )}
697
- extra_vars = [var for var in self .free_RVs if var in input_vars ]
698
- return ValueGradFunction (costs , grad_vars , extra_vars , ** kwargs )
697
+ extra_vars = [getattr (var .tag , "value_var" , var ) for var in self .free_RVs ]
698
+ extra_vars_and_values = {
699
+ var : self .test_point [var .name ]
700
+ for var in extra_vars
701
+ if var in input_vars and var not in grad_vars
702
+ }
703
+ return ValueGradFunction (costs , grad_vars , extra_vars_and_values , ** kwargs )
699
704
700
705
@property
701
706
def logpt (self ):
0 commit comments