@@ -597,6 +597,8 @@ class ValueGradFunction:
         See `numpy.can_cast` for a description of the options.
         Keep in mind that we cast the variables to the array *and*
         back from the array dtype to the variable dtype.
+    compute_grads: bool, default=True
+        If False, return only the logp, not the gradient.
     kwargs
         Extra arguments are passed on to `theano.function`.
 
@@ -611,7 +613,15 @@ class ValueGradFunction:
     """
 
     def __init__(
-        self, costs, grad_vars, extra_vars=None, dtype=None, casting="no", **kwargs
+        self,
+        costs,
+        grad_vars,
+        extra_vars=None,
+        *,
+        dtype=None,
+        casting="no",
+        compute_grads=True,
+        **kwargs
     ):
         from .distributions import TensorType
 
@@ -651,13 +661,13 @@ def __init__(
         for var in self._grad_vars:
             if not np.can_cast(var.dtype, self.dtype, casting):
                 raise TypeError(
-                    "Invalid dtype for variable %s. Can not "
-                    "cast to %s with casting rule %s." % (var.name, self.dtype, casting)
+                    f"Invalid dtype for variable {var.name}. Can not "
+                    f"cast to {self.dtype} with casting rule {casting}."
                 )
             if not np.issubdtype(var.dtype, np.floating):
                 raise TypeError(
-                    "Invalid dtype for variable %s. Must be "
-                    "floating point but is %s." % (var.name, var.dtype)
+                    f"Invalid dtype for variable {var.name}. Must be "
+                    f"floating point but is {var.dtype}."
                 )
 
         givens = []
@@ -677,13 +687,17 @@ def __init__(
             self._cost, grad_vars, self._ordering.vmap
         )
 
-        grad = tt.grad(self._cost_joined, self._vars_joined)
-        grad.name = "__grad"
+        if compute_grads:
+            grad = tt.grad(self._cost_joined, self._vars_joined)
+            grad.name = "__grad"
+            outputs = [self._cost_joined, grad]
+        else:
+            outputs = self._cost_joined
 
         inputs = [self._vars_joined]
 
         self._theano_function = theano.function(
-            inputs, [self._cost_joined, grad], givens=givens, **kwargs
+            inputs, outputs, givens=givens, **kwargs
         )
 
     def set_weights(self, values):
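Note on this hunk: `theano.function` returns a bare array when compiled with a single output and a list when compiled with a list of outputs, which is why `__call__` in the next hunk switches from tuple unpacking to indexing into `output`. A minimal sketch of that theano behavior, independent of this class:

import theano
import theano.tensor as tt

x = tt.dscalar("x")
cost = x ** 2

# A single output is returned directly as an array.
f_single = theano.function([x], cost)
f_single(3.0)  # array(9.0)

# A list of outputs comes back as a list [value, gradient].
grad = tt.grad(cost, x)
f_pair = theano.function([x], [cost, grad])
f_pair(3.0)  # [array(9.0), array(6.0)]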
@@ -723,12 +737,12 @@ def __call__(self, array, grad_out=None, extra_vars=None):
         else:
             out = grad_out
 
-        logp, dlogp = self._theano_function(array)
+        output = self._theano_function(array)
         if grad_out is None:
-            return logp, dlogp
+            return output
         else:
-            np.copyto(out, dlogp)
-            return logp
+            np.copyto(out, output[1])
+            return output[0]
 
     @property
     def profile(self):
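With this change, a caller that only needs the log probability can avoid building the gradient graph at all. A minimal usage sketch, assuming the commit is applied and that `Model.logp_dlogp_function` forwards extra keyword arguments on to `ValueGradFunction` via **kwargs, as it does in pymc3's model module:

import numpy as np
import pymc3 as pm

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)

# compute_grads=False compiles a logp-only function, so theano never
# builds the gradient graph for the joined cost.
func = model.logp_dlogp_function(compute_grads=False)
func.set_extra_values({})  # this model has no extra (non-gradient) variables

logp = func(np.zeros(func.size))  # a single value, not a (logp, dlogp) pair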