Skip to content

Commit 73a21ae

Browse files
twieckiferrine
authored andcommitted
Change default optimizer in OPVI to adagrad_window (#2218)
1 parent c306452 commit 73a21ae

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc3/variational/opvi.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
import theano.tensor as tt
3838

3939
import pymc3 as pm
40-
from .updates import adam
40+
from .updates import adagrad_window
4141
from ..distributions.dist_math import rho2sd, log_normal
4242
from ..model import modelcontext, ArrayOrdering, DictToArrayBijection
4343
from ..util import get_default_varnames
@@ -97,7 +97,7 @@ def random(self, size=None):
9797
"""
9898
return self.op.approx.random(size)
9999

100-
def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adam, test_optimizer=adam,
100+
def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
101101
more_obj_params=None, more_tf_params=None, more_updates=None, more_replacements=None):
102102
"""Calculates gradients for objective function, test function and then
103103
constructs updates for optimization step
@@ -157,7 +157,7 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adam, test_optimize
157157
resulting_updates.update(more_updates)
158158
return resulting_updates
159159

160-
def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adam,
160+
def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
161161
more_tf_params=None, more_replacements=None):
162162
tf_z = self.get_input(tf_n_mc)
163163
tf_target = self(tf_z, more_tf_params=more_tf_params)
@@ -168,7 +168,7 @@ def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adam,
168168
self.test_params +
169169
more_tf_params))
170170

171-
def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adam,
171+
def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
172172
more_obj_params=None, more_replacements=None):
173173
obj_z = self.get_input(obj_n_mc)
174174
obj_target = self(obj_z, more_obj_params=more_obj_params)
@@ -187,7 +187,7 @@ def get_input(self, n_mc):
187187
@memoize
188188
@change_flags(compute_test_value='off')
189189
def step_function(self, obj_n_mc=None, tf_n_mc=None,
190-
obj_optimizer=adam, test_optimizer=adam,
190+
obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
191191
more_obj_params=None, more_tf_params=None,
192192
more_updates=None, more_replacements=None, score=False,
193193
fn_kwargs=None):

0 commit comments

Comments
 (0)