@@ -98,7 +98,8 @@ def random(self, size=None):
         return self.op.approx.random(size)
 
     def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
-                more_obj_params=None, more_tf_params=None, more_updates=None, more_replacements=None):
+                more_obj_params=None, more_tf_params=None, more_updates=None,
+                more_replacements=None, total_grad_norm_constraint=None):
         """Calculates gradients for objective function, test function and then
         constructs updates for optimization step
 
@@ -120,27 +121,24 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
             Add custom updates to resulting updates
         more_replacements : `dict`
             Apply custom replacements before calculating gradients
+        total_grad_norm_constraint : `float`
+            Bounds gradient norm, prevents exploding gradient problem
 
         Returns
         -------
         :class:`ObjectiveUpdates`
         """
-        if more_obj_params is None:
-            more_obj_params = []
-        if more_tf_params is None:
-            more_tf_params = []
         if more_updates is None:
             more_updates = dict()
-        if more_replacements is None:
-            more_replacements = dict()
         resulting_updates = ObjectiveUpdates()
         if self.test_params:
             self.add_test_updates(
                 resulting_updates,
                 tf_n_mc=tf_n_mc,
                 test_optimizer=test_optimizer,
                 more_tf_params=more_tf_params,
-                more_replacements=more_replacements
+                more_replacements=more_replacements,
+                total_grad_norm_constraint=total_grad_norm_constraint
             )
         else:
             if tf_n_mc is not None:
@@ -152,30 +150,47 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
                 obj_n_mc=obj_n_mc,
                 obj_optimizer=obj_optimizer,
                 more_obj_params=more_obj_params,
-                more_replacements=more_replacements
+                more_replacements=more_replacements,
+                total_grad_norm_constraint=total_grad_norm_constraint
             )
         resulting_updates.update(more_updates)
         return resulting_updates
 
     def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
-                         more_tf_params=None, more_replacements=None):
+                         more_tf_params=None, more_replacements=None,
+                         total_grad_norm_constraint=None):
+        if more_tf_params is None:
+            more_tf_params = []
+        if more_replacements is None:
+            more_replacements = dict()
         tf_z = self.get_input(tf_n_mc)
         tf_target = self(tf_z, more_tf_params=more_tf_params)
         tf_target = theano.clone(tf_target, more_replacements, strict=False)
+        grads = pm.updates.get_or_compute_grads(tf_target, self.obj_params + more_tf_params)
+        if total_grad_norm_constraint is not None:
+            grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
         updates.update(
             test_optimizer(
-                tf_target,
+                grads,
                 self.test_params +
                 more_tf_params))
 
     def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
-                        more_obj_params=None, more_replacements=None):
+                        more_obj_params=None, more_replacements=None,
+                        total_grad_norm_constraint=None):
+        if more_obj_params is None:
+            more_obj_params = []
+        if more_replacements is None:
+            more_replacements = dict()
         obj_z = self.get_input(obj_n_mc)
         obj_target = self(obj_z, more_obj_params=more_obj_params)
         obj_target = theano.clone(obj_target, more_replacements, strict=False)
+        grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
+        if total_grad_norm_constraint is not None:
+            grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
         updates.update(
             obj_optimizer(
-                obj_target,
+                grads,
                 self.obj_params +
                 more_obj_params))
         if self.op.RETURNS_LOSS:
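
The Lasagne-style optimizers vendored in `pymc3.updates` accept either a scalar loss or an already computed list of gradients, which is what lets the calls above switch from `tf_target`/`obj_target` to `grads`. For intuition, here is a minimal NumPy sketch of the joint-norm clipping that `pm.total_norm_constraint` is expected to perform; it is an illustration under that assumption, not the Theano implementation, and the function name and `eps` guard are ours:

```python
import numpy as np

def clip_total_norm(grads, max_norm, eps=1e-7):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    # The norm is taken over *all* gradients together, not per tensor.
    joint_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # Leave gradients untouched below the threshold; shrink them proportionally above it.
    scale = min(1.0, max_norm / (joint_norm + eps))
    return [g * scale for g in grads]

# A spike in one gradient scales the whole set down together:
g1, g2 = np.array([3.0, 4.0]), np.array([0.0, 12.0])   # joint norm = 13
clipped = clip_total_norm([g1, g2], max_norm=1.0)
print(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))    # -> ~1.0
```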
@@ -189,8 +204,9 @@ def get_input(self, n_mc):
     def step_function(self, obj_n_mc=None, tf_n_mc=None,
                       obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
                       more_obj_params=None, more_tf_params=None,
-                      more_updates=None, more_replacements=None, score=False,
-                      fn_kwargs=None):
+                      more_updates=None, more_replacements=None,
+                      total_grad_norm_constraint=None,
+                      score=False, fn_kwargs=None):
         R"""Step function that should be called on each optimization step.
 
         Generally it solves the following problem:
@@ -215,6 +231,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
             Add custom params for test function optimizer
         more_updates : `dict`
             Add custom updates to resulting updates
+        total_grad_norm_constraint : `float`
+            Bounds gradient norm, prevents exploding gradient problem
         score : `bool`
             calculate loss on each step? Defaults to False for speed
         fn_kwargs : `dict`
@@ -236,7 +254,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
             more_obj_params=more_obj_params,
             more_tf_params=more_tf_params,
             more_updates=more_updates,
-            more_replacements=more_replacements)
+            more_replacements=more_replacements,
+            total_grad_norm_constraint=total_grad_norm_constraint)
         if score:
             step_fn = theano.function(
                 [], updates.loss, updates=updates, **fn_kwargs)
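
End to end, a caller only has to pass the new keyword once and the code above threads it down to the gradient computation. A hypothetical usage sketch, assuming the inference object exposes its `ObjectiveFunction` as `inference.objective`; the model, the constraint value, and the training loop are illustrative, not part of this diff:

```python
import pymc3 as pm

with pm.Model():
    pm.Normal('x', mu=0., sd=1.)
    inference = pm.ADVI()  # assumed to expose its ObjectiveFunction as `inference.objective`
    # Clip the joint gradient norm at 10 on every optimization step.
    step = inference.objective.step_function(total_grad_norm_constraint=10., score=True)
    for _ in range(1000):
        loss = step()  # one optimization step; returns the loss because score=True
```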