37
37
import theano .tensor as tt
38
38
39
39
import pymc3 as pm
40
- from .updates import adam
40
+ from .updates import adagrad_window
41
41
from ..distributions .dist_math import rho2sd , log_normal
42
42
from ..model import modelcontext , ArrayOrdering , DictToArrayBijection
43
43
from ..util import get_default_varnames
@@ -97,7 +97,7 @@ def random(self, size=None):
97
97
"""
98
98
return self .op .approx .random (size )
99
99
100
- def updates (self , obj_n_mc = None , tf_n_mc = None , obj_optimizer = adam , test_optimizer = adam ,
100
+ def updates (self , obj_n_mc = None , tf_n_mc = None , obj_optimizer = adagrad_window , test_optimizer = adagrad_window ,
101
101
more_obj_params = None , more_tf_params = None , more_updates = None , more_replacements = None ):
102
102
"""Calculates gradients for objective function, test function and then
103
103
constructs updates for optimization step
@@ -157,7 +157,7 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adam, test_optimize
157
157
resulting_updates .update (more_updates )
158
158
return resulting_updates
159
159
160
- def add_test_updates (self , updates , tf_n_mc = None , test_optimizer = adam ,
160
+ def add_test_updates (self , updates , tf_n_mc = None , test_optimizer = adagrad_window ,
161
161
more_tf_params = None , more_replacements = None ):
162
162
tf_z = self .get_input (tf_n_mc )
163
163
tf_target = self (tf_z , more_tf_params = more_tf_params )
@@ -168,7 +168,7 @@ def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adam,
168
168
self .test_params +
169
169
more_tf_params ))
170
170
171
- def add_obj_updates (self , updates , obj_n_mc = None , obj_optimizer = adam ,
171
+ def add_obj_updates (self , updates , obj_n_mc = None , obj_optimizer = adagrad_window ,
172
172
more_obj_params = None , more_replacements = None ):
173
173
obj_z = self .get_input (obj_n_mc )
174
174
obj_target = self (obj_z , more_obj_params = more_obj_params )
@@ -187,7 +187,7 @@ def get_input(self, n_mc):
187
187
@memoize
188
188
@change_flags (compute_test_value = 'off' )
189
189
def step_function (self , obj_n_mc = None , tf_n_mc = None ,
190
- obj_optimizer = adam , test_optimizer = adam ,
190
+ obj_optimizer = adagrad_window , test_optimizer = adagrad_window ,
191
191
more_obj_params = None , more_tf_params = None ,
192
192
more_updates = None , more_replacements = None , score = False ,
193
193
fn_kwargs = None ):
0 commit comments