
Commit 7379e1e

add total gradient norm to VI
1 parent 0e54ce6 commit 7379e1e

3 files changed (+38, -19 lines)


pymc3/tests/test_variational_inference.py

Lines changed: 3 additions & 2 deletions
@@ -366,12 +366,13 @@ def test_init_from_noize(self):
         (_advi, dict(start={}), None),
         (_fullrank_advi, dict(), None),
         (_svgd, dict(), None),
-        ('advi', dict(), None),
+        ('advi', dict(total_grad_norm_constraint=10), None),
         ('advi->fullrank_advi', dict(frac=.1), None),
         ('advi->fullrank_advi', dict(frac=1), ValueError),
         ('fullrank_advi', dict(), None),
-        ('svgd', dict(), None),
+        ('svgd', dict(total_grad_norm_constraint=10), None),
         ('svgd', dict(start={}), None),
+        ('asvgd', dict(start={}, total_grad_norm_constraint=10), None),
         ('svgd', dict(local_rv={_model.free_RVs[0]: (0, 1)}), ValueError)
     ]
 )
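
The parametrization above exercises the new keyword through the high-level fit interface. A minimal usage sketch, assuming the extra kwargs in the parametrization are forwarded by `pm.fit` down to `ObjectiveFunction.step_function` (the toy model here is illustrative only, not part of the test suite):

    import numpy as np
    import pymc3 as pm

    with pm.Model():
        mu = pm.Normal('mu', mu=0., sd=1.)
        pm.Normal('obs', mu=mu, sd=1., observed=np.random.randn(100))

        # Clip the joint gradient norm at 10 while running ADVI;
        # `total_grad_norm_constraint` is the keyword added by this commit.
        approx = pm.fit(n=1000, method='advi', total_grad_norm_constraint=10)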

pymc3/variational/operators.py

Lines changed: 0 additions & 1 deletion
@@ -63,7 +63,6 @@ def __call__(self, z, **kwargs):
         grad *= pm.floatX(-1)
         grad = theano.clone(grad, {op.input_matrix: z})
         grad = tt.grad(None, params, known_grads={z: grad})
-        grad = updates.total_norm_constraint(grad, 10)
         return grad

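The single deletion above drops the hard-coded clip of 10 inside the KSD operator; clipping is now applied centrally in `opvi.py` and only when `total_grad_norm_constraint` is given. For reference, an illustrative NumPy sketch of what total-norm clipping does (not the theano helper used in the code): all gradients are rescaled jointly so their combined L2 norm never exceeds the bound.

    import numpy as np

    def clip_by_total_norm(grads, max_norm, eps=1e-7):
        # Joint L2 norm over the whole list of gradient arrays.
        total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
        # Shrink when the norm exceeds the bound; otherwise scale is ~1.
        scale = min(total_norm, max_norm) / (total_norm + eps)
        return [g * scale for g in grads]

    # Example: a gradient list with joint norm 20 is scaled down to ~norm 5.
    clipped = clip_by_total_norm([np.full(4, 10.0)], max_norm=5.0)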
pymc3/variational/opvi.py

Lines changed: 35 additions & 16 deletions
@@ -98,7 +98,8 @@ def random(self, size=None):
         return self.op.approx.random(size)

     def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
-                more_obj_params=None, more_tf_params=None, more_updates=None, more_replacements=None):
+                more_obj_params=None, more_tf_params=None, more_updates=None,
+                more_replacements=None, total_grad_norm_constraint=None):
         """Calculates gradients for objective function, test function and then
         constructs updates for optimization step

@@ -120,27 +121,24 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
             Add custom updates to resulting updates
         more_replacements : `dict`
             Apply custom replacements before calculating gradients
+        total_grad_norm_constraint : `float`
+            Bounds gradient norm, prevents exploding gradient problem

         Returns
         -------
         :class:`ObjectiveUpdates`
         """
-        if more_obj_params is None:
-            more_obj_params = []
-        if more_tf_params is None:
-            more_tf_params = []
         if more_updates is None:
             more_updates = dict()
-        if more_replacements is None:
-            more_replacements = dict()
         resulting_updates = ObjectiveUpdates()
         if self.test_params:
             self.add_test_updates(
                 resulting_updates,
                 tf_n_mc=tf_n_mc,
                 test_optimizer=test_optimizer,
                 more_tf_params=more_tf_params,
-                more_replacements=more_replacements
+                more_replacements=more_replacements,
+                total_grad_norm_constraint=total_grad_norm_constraint
             )
         else:
             if tf_n_mc is not None:
@@ -152,30 +150,47 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
                 obj_n_mc=obj_n_mc,
                 obj_optimizer=obj_optimizer,
                 more_obj_params=more_obj_params,
-                more_replacements=more_replacements
+                more_replacements=more_replacements,
+                total_grad_norm_constraint=total_grad_norm_constraint
             )
         resulting_updates.update(more_updates)
         return resulting_updates

     def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
-                         more_tf_params=None, more_replacements=None):
+                         more_tf_params=None, more_replacements=None,
+                         total_grad_norm_constraint=None):
+        if more_tf_params is None:
+            more_tf_params = []
+        if more_replacements is None:
+            more_replacements = dict()
         tf_z = self.get_input(tf_n_mc)
         tf_target = self(tf_z, more_tf_params=more_tf_params)
         tf_target = theano.clone(tf_target, more_replacements, strict=False)
+        grads = pm.updates.get_or_compute_grads(tf_target, self.obj_params + more_tf_params)
+        if total_grad_norm_constraint is not None:
+            grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
         updates.update(
             test_optimizer(
-                tf_target,
+                grads,
                 self.test_params +
                 more_tf_params))

     def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
-                        more_obj_params=None, more_replacements=None):
+                        more_obj_params=None, more_replacements=None,
+                        total_grad_norm_constraint=None):
+        if more_obj_params is None:
+            more_obj_params = []
+        if more_replacements is None:
+            more_replacements = dict()
         obj_z = self.get_input(obj_n_mc)
         obj_target = self(obj_z, more_obj_params=more_obj_params)
         obj_target = theano.clone(obj_target, more_replacements, strict=False)
+        grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
+        if total_grad_norm_constraint is not None:
+            grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
         updates.update(
             obj_optimizer(
-                obj_target,
+                grads,
                 self.obj_params +
                 more_obj_params))
         if self.op.RETURNS_LOSS:
@@ -189,8 +204,9 @@ def get_input(self, n_mc):
     def step_function(self, obj_n_mc=None, tf_n_mc=None,
                       obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
                       more_obj_params=None, more_tf_params=None,
-                      more_updates=None, more_replacements=None, score=False,
-                      fn_kwargs=None):
+                      more_updates=None, more_replacements=None,
+                      total_grad_norm_constraint=None,
+                      score=False, fn_kwargs=None):
         R"""Step function that should be called on each optimization step.

         Generally it solves the following problem:
@@ -215,6 +231,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
             Add custom params for test function optimizer
         more_updates : `dict`
             Add custom updates to resulting updates
+        total_grad_norm_constraint : `float`
+            Bounds gradient norm, prevents exploding gradient problem
         score : `bool`
             calculate loss on each step? Defaults to False for speed
         fn_kwargs : `dict`
@@ -236,7 +254,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
             more_obj_params=more_obj_params,
             more_tf_params=more_tf_params,
             more_updates=more_updates,
-            more_replacements=more_replacements)
+            more_replacements=more_replacements,
+            total_grad_norm_constraint=total_grad_norm_constraint)
         if score:
             step_fn = theano.function(
                 [], updates.loss, updates=updates, **fn_kwargs)
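
The net effect in `add_obj_updates` and `add_test_updates` is a compute, optionally clip, then optimize pattern: gradients are materialized first, jointly rescaled by total norm when requested, and only then handed to the optimizer (which accepts either a loss or a precomputed gradient list, as the diff relies on). A condensed standalone sketch of that pattern; the import path `pymc3.variational.updates` and the toy shared variable are assumptions made for illustration:

    import numpy as np
    import theano
    import theano.tensor as tt
    from pymc3.variational.updates import (adagrad_window,
                                           get_or_compute_grads,
                                           total_norm_constraint)

    # Toy objective: a single shared parameter and a scalar loss.
    w = theano.shared(np.ones(3, dtype=theano.config.floatX), name='w')
    loss = tt.sum(w ** 2)
    params = [w]

    total_grad_norm_constraint = 10  # same role as the new keyword

    grads = get_or_compute_grads(loss, params)
    if total_grad_norm_constraint is not None:
        # Jointly rescale all gradients so the total norm stays bounded.
        grads = total_norm_constraint(grads, total_grad_norm_constraint)

    # The optimizer receives the (possibly clipped) gradients instead of the loss.
    updates = adagrad_window(grads, params)
    step = theano.function([], loss, updates=updates)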
