Commit 6279ed4

ferrine authored and twiecki committed
add class based optimizers
1 parent 7d3fccb commit 6279ed4

File tree

1 file changed: +54 -4 lines changed


pymc3/variational/updates.py

Lines changed: 54 additions & 4 deletions
@@ -97,7 +97,7 @@
 """

 from collections import OrderedDict
-
+import functools
 import numpy as np

 import theano
@@ -117,10 +117,28 @@
     "adam",
     "adamax",
     "norm_constraint",
-    "total_norm_constraint"
+    "total_norm_constraint",
+    "Sgd",
+    "Momentum",
+    "NesterovMomentum",
+    "Adagrad",
+    "RMSProp",
+    "AdaDelta",
+    "AdaMax",
+    "Adam",
 ]


+class Optimizer(object):
+    _opt = None
+
+    def __init__(self, *args, **kwargs):
+        self.opt = functools.partial(self._opt, *args, **kwargs)
+
+    def __call__(self, loss_or_grads, params):
+        return self.opt(loss_or_grads, params)
+
+
 def get_or_compute_grads(loss_or_grads, params):
     """Helper function returning a list of gradients

@@ -160,7 +178,7 @@ def get_or_compute_grads(loss_or_grads, params):
         return theano.grad(loss_or_grads, params)


-def sgd(loss_or_grads, params, learning_rate):
+def sgd(loss_or_grads, params, learning_rate=1e-3):
     """Stochastic Gradient Descent (SGD) updates

     Generates update expressions of the form:
@@ -190,6 +208,10 @@ def sgd(loss_or_grads, params, learning_rate):
     return updates


+class Sgd(Optimizer):
+    _opt = sgd
+
+
 def apply_momentum(updates, params=None, momentum=0.9):
     """Returns a modified update dictionary including momentum

@@ -277,6 +299,10 @@ def momentum(loss_or_grads, params, learning_rate, momentum=0.9):
     return apply_momentum(updates, momentum=momentum)


+class Momentum(Optimizer):
+    _opt = momentum
+
+
 def apply_nesterov_momentum(updates, params=None, momentum=0.9):
     """Returns a modified update dictionary including Nesterov momentum

@@ -331,7 +357,7 @@ def apply_nesterov_momentum(updates, params=None, momentum=0.9):
     return updates


-def nesterov_momentum(loss_or_grads, params, learning_rate, momentum=0.9):
+def nesterov_momentum(loss_or_grads, params, learning_rate=1e-3, momentum=0.9):
     """Stochastic Gradient Descent (SGD) updates with Nesterov momentum

     Generates update expressions of the form:
@@ -375,6 +401,10 @@ def nesterov_momentum(loss_or_grads, params, learning_rate, momentum=0.9):
     return apply_nesterov_momentum(updates, momentum=momentum)


+class NesterovMomentum(Optimizer):
+    _opt = nesterov_momentum
+
+
 def adagrad(loss_or_grads, params, learning_rate=1.0, epsilon=1e-6):
     """Adagrad updates

@@ -434,6 +464,10 @@ def adagrad(loss_or_grads, params, learning_rate=1.0, epsilon=1e-6):
     return updates


+class Adagrad(Optimizer):
+    _opt = adagrad
+
+
 def rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6):
     """RMSProp updates

@@ -495,6 +529,10 @@ def rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6):
     return updates


+class RMSProp(Optimizer):
+    _opt = rmsprop
+
+
 def adadelta(loss_or_grads, params, learning_rate=1.0, rho=0.95, epsilon=1e-6):
     """ Adadelta updates

@@ -579,6 +617,10 @@ def adadelta(loss_or_grads, params, learning_rate=1.0, rho=0.95, epsilon=1e-6):
     return updates


+class AdaDelta(Optimizer):
+    _opt = adadelta
+
+
 def adam(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
          beta2=0.999, epsilon=1e-8):
     """Adam updates
@@ -646,6 +688,10 @@ def adam(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
     return updates


+class Adam(Optimizer):
+    _opt = adam
+
+
 def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
            beta2=0.999, epsilon=1e-8):
     """Adamax updates
@@ -708,6 +754,10 @@ def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
     return updates


+class AdaMax(Optimizer):
+    _opt = adamax
+
+
 def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7):
     """Max weight norm constraints and gradient clipping

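For context, here is a minimal usage sketch of the class-based wrappers this commit adds. It is not part of the commit: the toy quadratic loss, the shared variable w, and the chosen hyperparameter values are illustrative assumptions. Each class binds its hyperparameters with functools.partial at construction time and applies the underlying update function when called with (loss_or_grads, params).

import numpy as np
import theano
import theano.tensor as tt

from pymc3.variational.updates import Adam, sgd

# Toy objective: a quadratic loss over a single shared parameter
# (illustrative only, not taken from the commit).
w = theano.shared(np.zeros(3), name='w')
loss = tt.sum((w - 1.0) ** 2)

# Functional API: hyperparameters and the loss are passed in one call.
updates_fn = sgd(loss, [w], learning_rate=1e-2)

# Class-based API added by this commit: hyperparameters are bound first
# (stored via functools.partial), and the configured optimizer is applied
# to (loss_or_grads, params) later.
opt = Adam(learning_rate=1e-3)
updates_cls = opt(loss, [w])

# Both calls return an OrderedDict of Theano updates usable in theano.function.
train = theano.function([], loss, updates=updates_cls)

One consequence of the design worth noting: because __init__ forwards positional *args to functools.partial ahead of loss_or_grads, hyperparameters are safest passed as keyword arguments, as in the sketch above.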