@@ -97,7 +97,7 @@
 """
 
 from collections import OrderedDict
-
+import functools
 import numpy as np
 
 import theano
@@ -117,10 +117,28 @@
     "adam",
     "adamax",
     "norm_constraint",
-    "total_norm_constraint"
+    "total_norm_constraint",
+    "Sgd",
+    "Momentum",
+    "NesterovMomentum",
+    "Adagrad",
+    "RMSProp",
+    "AdaDelta",
+    "AdaMax",
+    "Adam",
 ]
 
 
+class Optimizer(object):
+    _opt = None
+
+    def __init__(self, *args, **kwargs):
+        self.opt = functools.partial(self._opt, *args, **kwargs)
+
+    def __call__(self, loss_or_grads, params):
+        return self.opt(loss_or_grads, params)
+
+
 def get_or_compute_grads(loss_or_grads, params):
     """Helper function returning a list of gradients
 
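The new `Optimizer` base class is a thin wrapper around the existing functional update rules: the constructor captures hyperparameters with `functools.partial`, and `__call__` forwards the loss (or precomputed gradients) and parameter list to the wrapped function. The standalone sketch below is not part of the diff; the plain-Python stand-in for `sgd` and the `staticmethod` wrapper on `_opt` are assumptions made here so the example runs without Theano and so attribute lookup does not bind the stored function to the instance.

```python
import functools


class Optimizer(object):
    # Update rule wrapped by this class; subclasses override it.
    _opt = None

    def __init__(self, *args, **kwargs):
        # Capture hyperparameters (learning rate, momentum, ...) now;
        # the loss/gradients and parameters are supplied later per call.
        self.opt = functools.partial(self._opt, *args, **kwargs)

    def __call__(self, loss_or_grads, params):
        return self.opt(loss_or_grads, params)


def sgd(loss_or_grads, params, learning_rate=1e-3):
    # Stand-in for the Theano-based rule: it only records the symbolic
    # form of each update so the sketch stays runnable anywhere.
    return {p: "%s - %s * grad(%s)" % (p, learning_rate, p) for p in params}


class Sgd(Optimizer):
    # staticmethod is an assumption of this sketch: it keeps the plain
    # function from being turned into a bound method on lookup.
    _opt = staticmethod(sgd)


opt = Sgd(learning_rate=0.01)    # hyperparameters fixed up front
print(opt("loss", ["W", "b"]))   # loss and params supplied at call time
```

Configured this way, an optimizer instance can be created once and handed to training code that only knows the `(loss_or_grads, params)` calling convention.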
@@ -160,7 +178,7 @@ def get_or_compute_grads(loss_or_grads, params):
         return theano.grad(loss_or_grads, params)
 
 
-def sgd(loss_or_grads, params, learning_rate):
+def sgd(loss_or_grads, params, learning_rate=1e-3):
     """Stochastic Gradient Descent (SGD) updates
 
     Generates update expressions of the form:
@@ -190,6 +208,10 @@ def sgd(loss_or_grads, params, learning_rate):
     return updates
 
 
+class Sgd(Optimizer):
+    _opt = sgd
+
+
 def apply_momentum(updates, params=None, momentum=0.9):
     """Returns a modified update dictionary including momentum
 
@@ -277,6 +299,10 @@ def momentum(loss_or_grads, params, learning_rate, momentum=0.9):
     return apply_momentum(updates, momentum=momentum)
 
 
+class Momentum(Optimizer):
+    _opt = momentum
+
+
 def apply_nesterov_momentum(updates, params=None, momentum=0.9):
     """Returns a modified update dictionary including Nesterov momentum
 
@@ -331,7 +357,7 @@ def apply_nesterov_momentum(updates, params=None, momentum=0.9):
     return updates
 
 
-def nesterov_momentum(loss_or_grads, params, learning_rate, momentum=0.9):
+def nesterov_momentum(loss_or_grads, params, learning_rate=1e-3, momentum=0.9):
     """Stochastic Gradient Descent (SGD) updates with Nesterov momentum
 
     Generates update expressions of the form:
@@ -375,6 +401,10 @@ def nesterov_momentum(loss_or_grads, params, learning_rate, momentum=0.9):
     return apply_nesterov_momentum(updates, momentum=momentum)
 
 
+class NesterovMomentum(Optimizer):
+    _opt = nesterov_momentum
+
+
 def adagrad(loss_or_grads, params, learning_rate=1.0, epsilon=1e-6):
     """Adagrad updates
 
@@ -434,6 +464,10 @@ def adagrad(loss_or_grads, params, learning_rate=1.0, epsilon=1e-6):
     return updates
 
 
+class Adagrad(Optimizer):
+    _opt = adagrad
+
+
 def rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6):
     """RMSProp updates
 
@@ -495,6 +529,10 @@ def rmsprop(loss_or_grads, params, learning_rate=1.0, rho=0.9, epsilon=1e-6):
     return updates
 
 
+class RMSProp(Optimizer):
+    _opt = rmsprop
+
+
 def adadelta(loss_or_grads, params, learning_rate=1.0, rho=0.95, epsilon=1e-6):
     """ Adadelta updates
 
@@ -579,6 +617,10 @@ def adadelta(loss_or_grads, params, learning_rate=1.0, rho=0.95, epsilon=1e-6):
     return updates
 
 
+class AdaDelta(Optimizer):
+    _opt = adadelta
+
+
 def adam(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
          beta2=0.999, epsilon=1e-8):
     """Adam updates
@@ -646,6 +688,10 @@ def adam(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
     return updates
 
 
+class Adam(Optimizer):
+    _opt = adam
+
+
 def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
            beta2=0.999, epsilon=1e-8):
     """Adamax updates
@@ -708,6 +754,10 @@ def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
     return updates
 
 
+class AdaMax(Optimizer):
+    _opt = adamax
+
+
 def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7):
     """Max weight norm constraints and gradient clipping
 
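For context, here is a hypothetical end-to-end comparison of the two APIs this change provides, assuming the module is importable as `lasagne.updates`, that Theano is installed, and that the class wrappers act as partial application of their functional counterparts; the toy least-squares objective and the variable names are illustrative only.

```python
import numpy as np
import theano
import theano.tensor as T

from lasagne.updates import adam, Adam  # assumed import path

# Toy least-squares objective over a single shared weight vector.
W = theano.shared(np.zeros(3, dtype=theano.config.floatX), name="W")
x = T.vector("x")
t = T.scalar("t")
loss = T.sqr(T.dot(x, W) - t)

# Functional API: graph inputs and hyperparameters in one call.
updates = adam(loss, [W], learning_rate=1e-3)

# Class-based API: hyperparameters first, graph inputs later, so the
# configured optimizer can be passed around and reused.
optimizer = Adam(learning_rate=1e-3)
updates = optimizer(loss, [W])

train = theano.function([x, t], loss, updates=updates)
print(train(np.ones(3, dtype=theano.config.floatX),
            np.asarray(2.0, dtype=theano.config.floatX)))
```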