@@ -581,8 +581,9 @@ class ValueGradFunction:
 
     Parameters
     ----------
-    cost: theano variable
-        The value that we compute with its gradient.
+    costs: list of theano variables
+        We compute the weighted sum of the specified theano values, and the gradient
+        of that sum. The weights can be specified with `ValueGradFunction.set_weights`.
     grad_vars: list of named theano variables or None
         The arguments with respect to which the gradient is computed.
     extra_vars: list of named theano variables or None
@@ -610,7 +611,7 @@ class ValueGradFunction:
     """
 
     def __init__(
-        self, cost, grad_vars, extra_vars=None, dtype=None, casting="no", **kwargs
+        self, costs, grad_vars, extra_vars=None, dtype=None, casting="no", **kwargs
     ):
         from .distributions import TensorType
 
@@ -623,19 +624,30 @@ def __init__(
         if len(set(names)) != len(names):
             raise ValueError("Names of the arguments are not unique.")
 
-        if cost.ndim > 0:
-            raise ValueError("Cost must be a scalar.")
-
         self._grad_vars = grad_vars
         self._extra_vars = extra_vars
         self._extra_var_names = {var.name for var in extra_vars}
+
+        if dtype is None:
+            dtype = theano.config.floatX
+        self.dtype = dtype
+
+        self._n_costs = len(costs)
+        if self._n_costs == 0:
+            raise ValueError("At least one cost is required.")
+        weights = np.ones(self._n_costs - 1, dtype=self.dtype)
+        self._weights = theano.shared(weights, "__weights")
+
+        cost = costs[0]
+        for i, val in enumerate(costs[1:]):
+            if cost.ndim > 0 or val.ndim > 0:
+                raise ValueError("All costs must be scalar.")
+            cost = cost + self._weights[i] * val
+
         self._cost = cost
         self._ordering = ArrayOrdering(grad_vars)
         self.size = self._ordering.size
         self._extra_are_set = False
-        if dtype is None:
-            dtype = theano.config.floatX
-        self.dtype = dtype
         for var in self._grad_vars:
             if not np.can_cast(var.dtype, self.dtype, casting):
                 raise TypeError(
@@ -674,6 +686,11 @@ def __init__(
             inputs, [self._cost_joined, grad], givens=givens, **kwargs
         )
 
+    def set_weights(self, values):
+        if values.shape != (self._n_costs - 1,):
+            raise ValueError("Invalid shape. Must be (n_costs - 1,).")
+        self._weights.set_value(values)
+
     def set_extra_values(self, extra_vars):
         self._extra_are_set = True
         for var in self._extra_vars:
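A minimal sketch of the mechanism introduced above (not part of the diff): the constructor combines `costs[0]` with a shared weight vector times the remaining costs, and `set_weights` changes those weights later without recompiling. The names below (`c0`, `c1`, `costs`, `weights`, `total`, `f`) are illustrative only.

    import numpy as np
    import theano
    import theano.tensor as tt

    # Two scalar costs, combined as costs[0] + weights[0] * costs[1].
    c0 = tt.scalar("c0")
    c1 = tt.scalar("c1")
    costs = [c0, c1]

    weights = theano.shared(np.ones(len(costs) - 1, dtype=theano.config.floatX), "__weights")
    total = costs[0]
    for i, val in enumerate(costs[1:]):
        total = total + weights[i] * val

    f = theano.function([c0, c1], total)
    print(f(1.0, 2.0))  # 3.0: the default weight of 1 reproduces the plain sum
    weights.set_value(np.array([0.5], dtype=theano.config.floatX))
    print(f(1.0, 2.0))  # 2.0: re-weighting takes effect without recompiling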
@@ -940,7 +957,18 @@ def dlogp_array(self):
         vars = inputvars(self.cont_vars)
         return self.bijection.mapf(self.fastdlogp(vars))
 
-    def logp_dlogp_function(self, grad_vars=None, **kwargs):
+    def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
+        """Compile a theano function that computes logp and gradient.
+
+        Parameters
+        ----------
+        grad_vars: list of random variables, optional
+            Compute the gradient with respect to those variables. If None,
+            use all free random variables of this model.
+        tempered: bool
+            Compute the tempered logp `free_logp + alpha * observed_logp`.
+            `alpha` can be changed using `ValueGradFunction.set_weights([alpha])`.
+        """
         if grad_vars is None:
             grad_vars = list(typefilter(self.free_RVs, continuous_types))
         else:
@@ -949,9 +977,22 @@ def logp_dlogp_function(self, grad_vars=None, **kwargs):
                     raise ValueError(
                         "Can only compute the gradient of " "continuous types: %s" % var
                     )
+
+        if tempered:
+            with self:
+                free_RVs_logp = tt.sum([
+                    tt.sum(var.logpt) for var in self.free_RVs + self.potentials
+                ])
+                observed_RVs_logp = tt.sum([
+                    tt.sum(var.logpt) for var in self.observed_RVs
+                ])
+
+            costs = [free_RVs_logp, observed_RVs_logp]
+        else:
+            costs = [self.logpt]
         varnames = [var.name for var in grad_vars]
         extra_vars = [var for var in self.free_RVs if var.name not in varnames]
-        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)
+        return ValueGradFunction(costs, grad_vars, extra_vars, **kwargs)
 
     @property
     def logpt(self):
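A hedged usage sketch of the tempered path (not part of the diff): with `tempered=True` the compiled function evaluates `free_logp + weight * observed_logp`, and the single weight can be moved between 0 and 1 with `set_weights`. The model and names (`a`, `b`, `func`, `x`) are illustrative only.

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        a = pm.Normal("a", mu=0.0, sigma=1.0)
        b = pm.Normal("b", mu=a, sigma=1.0, observed=2.0)

    # costs = [free_RVs_logp, observed_RVs_logp], so n_costs == 2 and
    # set_weights expects a single value: the weight on the observed term.
    func = model.logp_dlogp_function(tempered=True)
    func.set_extra_values({})  # no extra vars when grad_vars covers all free RVs

    x = np.zeros(func.size, dtype=func.dtype)
    logp_full, _ = func(x)  # default weight 1: free + observed, the full model logp

    func.set_weights(np.array([0.0], dtype=func.dtype))
    logp_free_only, _ = func(x)  # weight 0: only the free/prior term remains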
@@ -1050,7 +1091,7 @@ def add_coords(self, coords):
             return
 
         for name in coords:
-            if name in { "draw", "chain" }:
+            if name in {"draw", "chain"}:
                 raise ValueError(
                     "Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs."
                 )