
Commit b81a9f7

ferrine authored and twiecki committed
Variational Inference uses floatX (#2221)
* Variational Inference uses floatX
* fix test that hardcoded float32
* pylint fix
* use floatX inside function
* change scope
* remove redundant floatX on python floats following @twiecki's suggestion
1 parent 73a21ae commit b81a9f7

12 files changed: +82, -54 lines

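
Background for the change, sketched below: under theano.config.floatX = 'float32', any value entering the graph with an explicit float64 dtype (a numpy array, a generator batch, pandas data) upcasts downstream tensors to float64, which defeats float32/GPU runs. Plain Python float literals are autocast to floatX by Theano, which is why the commit also removes redundant floatX calls on them. A minimal, hypothetical illustration (not part of the commit; assumes floatX = 'float32'):

    import numpy as np
    import theano
    import theano.tensor as tt

    x = tt.vector('x')       # dtype follows theano.config.floatX
    w = np.random.rand(3)    # numpy defaults to float64
    print((x * 2.).dtype)    # 'float32': Python floats are autocast
    print((x * w).dtype)     # 'float64': the array's explicit dtype wins
    print((x * np.asarray(w, dtype=theano.config.floatX)).dtype)  # 'float32'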

pymc3/data.py (2 additions, 2 deletions)

@@ -61,7 +61,7 @@ def make_variable(self, gop, name=None):
     def __init__(self, generator):
         if not pm.vartypes.isgenerator(generator):
             raise TypeError('Object should be generator like')
-        self.test_value = copy(next(generator))
+        self.test_value = pm.smartfloatX(copy(next(generator)))
         # make pickling potentially possible
         self._yielded_test_value = False
         self.gen = generator
@@ -75,7 +75,7 @@ def __next__(self):
             self._yielded_test_value = True
             return self.test_value
         else:
-            return copy(next(self.gen))
+            return pm.smartfloatX(copy(next(self.gen)))

     # python2 generator
     next = __next__
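
The effect of the data.py change: batches are cast as they are drawn, so a float64-yielding generator can no longer upcast a float32 graph. A rough usage sketch (hypothetical, assuming floatX = 'float32'):

    import numpy as np
    import pymc3 as pm

    def batches():
        while True:
            yield np.random.rand(10)   # numpy yields float64 by default

    gop = pm.generator(batches())      # the test value and every draw now
    print(gop.dtype)                   # pass through pm.smartfloatX -> 'float32'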

pymc3/distributions/dist_math.py (29 additions, 25 deletions)

@@ -14,7 +14,8 @@
 from ..math import logdet as _logdet
 from pymc3.theanof import floatX

-c = - 0.5 * np.log(2 * np.pi)
+f = floatX
+c = - .5 * np.log(2. * np.pi)


 def bound(logp, *conditions, **kwargs):
@@ -81,34 +82,34 @@ def std_cdf(x):
     """
     Calculates the standard normal cumulative distribution function.
     """
-    return 0.5 + 0.5 * tt.erf(x / tt.sqrt(2.))
+    return .5 + .5 * tt.erf(x / tt.sqrt(2.))


 def i0(x):
     """
     Calculates the 0 order modified Bessel function of the first kind""
     """
-    return tt.switch(tt.lt(x, 5), 1 + x**2 / 4 + x**4 / 64 + x**6 / 2304 + x**8 / 147456
-                     + x**10 / 14745600 + x**12 / 2123366400,
-                     np.e**x / (2 * np.pi * x)**0.5 * (1 + 1 / (8 * x) + 9 / (128 * x**2) + 225 / (3072 * x**3)
-                     + 11025 / (98304 * x**4)))
+    return tt.switch(tt.lt(x, 5), 1. + x**2 / 4. + x**4 / 64. + x**6 / 2304. + x**8 / 147456.
+                     + x**10 / 14745600. + x**12 / 2123366400.,
+                     np.e**x / (2. * np.pi * x)**0.5 * (1. + 1. / (8. * x) + 9. / (128. * x**2) + 225. / (3072 * x**3)
+                     + 11025. / (98304. * x**4)))


 def i1(x):
     """
     Calculates the 1 order modified Bessel function of the first kind""
     """
-    return tt.switch(tt.lt(x, 5), x / 2 + x**3 / 16 + x**5 / 384 + x**7 / 18432 +
-                     x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
-                     np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
-                     + 14175 / (98304 * x**4)))
+    return tt.switch(tt.lt(x, 5), x / 2. + x**3 / 16. + x**5 / 384. + x**7 / 18432. +
+                     x**9 / 1474560. + x**11 / 176947200. + x**13 / 29727129600.,
+                     np.e**x / (2. * np.pi * x)**0.5 * (1. - 3. / (8. * x) + 15. / (128. * x**2) + 315. / (3072. * x**3)
+                     + 14175. / (98304. * x**4)))


 def sd2rho(sd):
     """
     `sd -> rho` theano converter
     :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
-    return tt.log(tt.exp(sd) - 1)
+    return tt.log(tt.exp(sd) - 1.)


 def rho2sd(rho):
@@ -122,13 +123,15 @@ def log_normal(x, mean, **kwargs):
     """
     Calculate logarithm of normal distribution at point `x`
     with given `mean` and `std`
+
     Parameters
     ----------
     x : Tensor
         point of evaluation
     mean : Tensor
         mean of normal distribution
     kwargs : one of parameters `{sd, tau, w, rho}`
+
     Notes
     -----
     There are four variants for density parametrization.
@@ -143,7 +146,7 @@ def log_normal(x, mean, **kwargs):
     w = kwargs.get('w')
     rho = kwargs.get('rho')
     tau = kwargs.get('tau')
-    eps = kwargs.get('eps', 0.0)
+    eps = kwargs.get('eps', 0.)
     check = sum(map(lambda a: a is not None, [sd, w, rho, tau]))
     if check > 1:
         raise ValueError('more than one required kwarg is passed')
@@ -157,14 +160,15 @@ def log_normal(x, mean, **kwargs):
         std = rho2sd(rho)
     else:
         std = tau**(-1)
-    std += eps
-    return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2)
+    std += f(eps)
+    return f(c) - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2. * std ** 2)


 def log_normal_mv(x, mean, gpu_compat=False, **kwargs):
     """
     Calculate logarithm of normal distribution at point `x`
     with given `mean` and `sigma` matrix
+
     Parameters
     ----------
     x : Tensor
@@ -173,8 +177,8 @@ def log_normal_mv(x, mean, gpu_compat=False, **kwargs):
         mean of normal distribution
     kwargs : one of parameters `{cov, tau, chol}`

-    Flags
-    ----------
+    Other Parameters
+    ----------------
     gpu_compat : False, because LogDet is not GPU compatible yet.
         If this is set as true, the GPU compatible (but numerically unstable) log(det) is used.

@@ -212,10 +216,10 @@ def logdet(m):
     T = tt.nlinalg.matrix_inverse(S)
     log_det = -logdet(S)
     delta = x - mean
-    k = S.shape[0]
-    result = k * tt.log(2 * np.pi) - log_det
+    k = f(S.shape[0])
+    result = k * tt.log(2. * np.pi) - log_det
     result += delta.dot(T).dot(delta)
-    return -1 / 2. * result
+    return -.5 * result


 def MvNormalLogp():
@@ -240,25 +244,25 @@ def MvNormalLogp():
     cholesky = Cholesky(nofail=True, lower=True)

     n, k = delta.shape
-
+    n, k = f(n), f(k)
     chol_cov = cholesky(cov)
     diag = tt.nlinalg.diag(chol_cov)
     ok = tt.all(diag > 0)

     chol_cov = tt.switch(ok, chol_cov, tt.fill(chol_cov, 1))
     delta_trans = solve_lower(chol_cov, delta.T).T

-    result = n * k * tt.log(2 * np.pi)
-    result += 2.0 * n * tt.sum(tt.log(diag))
-    result += (delta_trans ** 2).sum()
-    result = -0.5 * result
+    result = n * k * tt.log(f(2) * np.pi)
+    result += f(2) * n * tt.sum(tt.log(diag))
+    result += (delta_trans ** f(2)).sum()
+    result = f(-.5) * result
     logp = tt.switch(ok, result, -np.inf)

     def dlogp(inputs, gradients):
         g_logp, = gradients
         cov, delta = inputs

-        g_logp.tag.test_value = floatX(np.array(1.))
+        g_logp.tag.test_value = floatX(1.)
         n, k = delta.shape

         chol_cov = cholesky(cov)
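
Two distinct fixes are folded into dist_math.py. First, symbolic shapes such as delta.shape are int64, and multiplying int64 by float32 upcasts to float64, hence n, k = f(n), f(k). Second, module-level constants such as c = -.5 * np.log(2. * np.pi) evaluate to numpy float64 scalars, hence f(c) at the point of use. A quick check of the shape case (hypothetical; .astype is a rough stand-in for pymc3's floatX, assumes floatX = 'float32'):

    import theano.tensor as tt

    delta = tt.matrix('delta')            # float32 under floatX='float32'
    n, k = delta.shape                    # symbolic int64 scalars
    print((n * tt.log(delta)).dtype)      # 'float64': int64 upcasts float32
    print((n.astype('float32') * tt.log(delta)).dtype)  # 'float32'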

pymc3/math.py (4 additions, 4 deletions)

@@ -33,11 +33,11 @@ def logsumexp(x, axis=None):


 def invlogit(x, eps=sys.float_info.epsilon):
-    return (1 - 2 * eps) / (1 + tt.exp(-x)) + eps
+    return (1. - 2. * eps) / (1. + tt.exp(-x)) + eps


 def logit(p):
-    return tt.log(p / (1 - p))
+    return tt.log(p / (floatX(1) - p))


 def flatten_list(tensors):
@@ -82,11 +82,11 @@ def __str__(self):


 def probit(p):
-    return -sqrt(2) * erfcinv(2 * p)
+    return -sqrt(2.) * erfcinv(2. * p)


 def invprobit(x):
-    return 0.5 * erfc(-x / sqrt(2))
+    return .5 * erfc(-x / sqrt(2.))


 def expand_packed_triangular(n, packed, lower=True, diagonal_only=False):
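
In math.py, floatX(1) pins the constant in logit to the configured dtype instead of relying on autocast rules, and the literal cleanups (2 -> 2.) keep the remaining scalars floating-point. A round trip should stay in floatX, e.g. (hypothetical sketch):

    import numpy as np
    import theano
    import theano.tensor as tt
    from pymc3.math import logit, invlogit

    p = tt.vector('p')
    expr = invlogit(logit(p))
    print(expr.dtype)          # expected to remain theano.config.floatX
    f = theano.function([p], expr)
    print(f(np.array([.1, .5, .9], dtype=theano.config.floatX)))  # ~ [.1 .5 .9]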

pymc3/model.py (10 additions, 9 deletions)

@@ -198,8 +198,8 @@ def scaling(self):
             denom = self.logp_elemwiset.shape[0]
         else:
             denom = 1
-        coef = tt.as_tensor(total_size) / denom
-        return coef
+        coef = pm.floatX(tt.as_tensor(total_size)) / pm.floatX(denom)
+        return pm.floatX(coef)


 class InitContextMeta(type):
@@ -840,19 +840,20 @@ def init_value(self):
 def pandas_to_array(data):
     if hasattr(data, 'values'):  # pandas
         if data.isnull().any().any():  # missing values
-            return np.ma.MaskedArray(data.values, data.isnull().values)
+            ret = np.ma.MaskedArray(data.values, data.isnull().values)
         else:
-            return data.values
+            ret = data.values
     elif hasattr(data, 'mask'):
-        return data
+        ret = data
     elif isinstance(data, theano.gof.graph.Variable):
-        return data
+        ret = data
     elif sps.issparse(data):
-        return data
+        ret = data
     elif isgenerator(data):
-        return generator(data)
+        ret = generator(data)
     else:
-        return np.asarray(data)
+        ret = np.asarray(data)
+    return pm.smartfloatX(ret)


 def as_tensor(data, name, model, distribution):
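
pandas_to_array previously returned from each branch directly; routing every branch through ret lets the cast happen exactly once at the end, and smartfloatX means integer data keeps its dtype. Roughly (a hypothetical session, assuming floatX = 'float32'):

    import numpy as np
    import pandas as pd
    from pymc3.model import pandas_to_array

    df = pd.DataFrame({'y': [1., np.nan, 3.]})   # float64, one missing value
    arr = pandas_to_array(df)
    print(type(arr).__name__, arr.dtype)         # MaskedArray float32

    idx = pandas_to_array(np.array([0, 1, 2]))   # integer data
    print(idx.dtype)                             # int64: left untouched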

pymc3/tests/conftest.py (9 additions, 0 deletions)

@@ -7,3 +7,12 @@ def theano_config():
     config = theano.configparser.change_flags(compute_test_value='raise')
     with config:
         yield
+
+
+@pytest.fixture(scope='function')
+def strict_float32():
+    config = theano.configparser.change_flags(
+        warn_float64='raise',
+        floatX='float32')
+    with config:
+        yield
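
The new fixture turns any accidental float64 tensor into a hard failure for the duration of a test. It is opted into by name, e.g. (a hypothetical test, not part of the commit):

    import numpy as np
    import pytest
    import theano

    @pytest.mark.usefixtures('strict_float32')
    def test_graph_stays_float32():
        x = theano.shared(np.zeros(3, dtype='float32'))
        y = x * 2.                  # Python float autocasts: still float32
        assert y.dtype == 'float32'
        # x * np.random.rand(3) would raise here, because
        # warn_float64='raise' errors on any float64 tensor creation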

pymc3/tests/test_theanof.py (5 additions, 5 deletions)

@@ -11,14 +11,14 @@
 def integers():
     i = 0
     while True:
-        yield np.float32(i)
+        yield floatX(i)
         i += 1


 def integers_ndim(ndim):
     i = 0
     while True:
-        yield np.ones((2,) * ndim) * i
+        yield floatX(np.ones((2,) * ndim) * i)
         i += 1


@@ -47,15 +47,15 @@ def test_ndim(self):
     def test_cloning_available(self):
         gop = generator(integers())
         res = gop ** 2
-        shared = theano.shared(np.float32(10))
+        shared = theano.shared(floatX(10))
         res1 = theano.clone(res, {gop: shared})
         f = theano.function([], res1)
         assert f() == np.float32(100)

     def test_default_value(self):
         def gen():
             for i in range(2):
-                yield np.ones((10, 10)) * i
+                yield floatX(np.ones((10, 10)) * i)

         gop = generator(gen(), np.ones((10, 10)) * 10)
         f = theano.function([], gop)
@@ -68,7 +68,7 @@ def gen():
     def test_set_gen_and_exc(self):
         def gen():
             for i in range(2):
-                yield np.ones((10, 10)) * i
+                yield floatX(np.ones((10, 10)) * i)

         gop = generator(gen())
         f = theano.function([], gop)

pymc3/tests/test_variational_inference.py (2 additions, 1 deletion)

@@ -66,6 +66,7 @@ def _test_aevb(self):


 class TestApproximates:
+    @pytest.mark.usefixtures('strict_float32')
     class Base(SeededTest):
         inference = None
         NITER = 12000
@@ -202,7 +203,7 @@ def test_optimizer_minibatch_with_callback(self):
         def create_minibatch(data):
             while True:
                 data = np.roll(data, 100, axis=0)
-                yield data[:100]
+                yield pm.floatX(data[:100])

         minibatches = create_minibatch(data)
         with Model():

pymc3/theanof.py (10 additions, 0 deletions)

@@ -18,6 +18,7 @@
     'inputvars',
     'cont_inputs',
     'floatX',
+    'smartfloatX',
     'jacobian',
     'CallableTensor',
     'join_nonshared_inputs',
@@ -67,6 +68,15 @@ def floatX(X):
         # Scalar passed
         return np.asarray(X, dtype=theano.config.floatX)

+
+def smartfloatX(x):
+    """
+    Convert non int types to floatX
+    """
+    if str(x.dtype).startswith('float'):
+        x = floatX(x)
+    return x
+
 """
 Theano derivative functions
 """

pymc3/variational/approximations.py (4 additions, 2 deletions)

@@ -317,8 +317,10 @@ def randidx(self, size=None):
         else:
             size = tuple(np.atleast_1d(size))
         return (self._rng
-                .uniform(size=size, low=0.0, high=self.histogram.shape[0] - 1e-16)
-                .astype('int64'))
+                .uniform(size=size,
+                         low=pm.floatX(0),
+                         high=pm.floatX(self.histogram.shape[0]) - pm.floatX(1e-16))
+                .astype('int32'))

     def random_global(self, size=None, no_rand=False):
         theano_condition_is_here = isinstance(no_rand, tt.Variable)
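
randidx samples indices by drawing uniform reals in [0, N) and truncating; building the bounds with pm.floatX keeps the RNG stream out of float64, and the small subtraction is meant to keep the high end strictly below N before truncation. The numpy analogue of the idea (illustrative only):

    import numpy as np

    rng = np.random.RandomState(0)
    N = 10
    u = rng.uniform(size=5, low=np.float32(0),
                    high=np.float32(N) - np.float32(1e-16))
    idx = u.astype('int32')        # truncation gives valid indices 0..N-1
    print(idx)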

pymc3/variational/operators.py (3 additions, 2 deletions)

@@ -2,6 +2,7 @@
 from pymc3.variational.opvi import Operator, ObjectiveFunction, _warn_not_used
 from pymc3.variational.stein import Stein
 from pymc3.variational import updates
+import pymc3 as pm

 __all__ = [
     'KL',
@@ -59,7 +60,7 @@ def __call__(self, z, **kwargs):
             params = self.obj_params + kwargs['more_obj_params']
         else:
             params = self.test_params + kwargs['more_tf_params']
-        grad *= -1
+        grad *= pm.floatX(-1)
         grad = theano.clone(grad, {op.input_matrix: z})
         grad = tt.grad(None, params, known_grads={z: grad})
         grad = updates.total_norm_constraint(grad, 10)
@@ -103,7 +104,7 @@ def __init__(self, approx):
     def apply(self, f):
         # f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
         stein = Stein(self.approx, f, self.input_matrix)
-        return -1 * stein.grad
+        return pm.floatX(-1) * stein.grad


 class AKSD(KSD):
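
Here the bare scalar -1 becomes an explicitly typed constant, so the sign flip is a same-dtype multiply rather than a reliance on Theano's scalar upcast rules. What pm.floatX(-1) boils down to, roughly (hypothetical, floatX = 'float32'):

    import numpy as np
    import theano.tensor as tt

    grad = tt.vector('grad')                   # floatX, e.g. 'float32'
    neg_one = np.asarray(-1, dtype='float32')  # ~ pm.floatX(-1)
    print((grad * neg_one).dtype)              # 'float32' by construction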

pymc3/variational/opvi.py (1 addition, 1 deletion)

@@ -581,7 +581,7 @@ def normalizing_constant(self):
         # if not scale_cost_to_minibatch: t=1
         t = tt.switch(self.scale_cost_to_minibatch, t,
                       tt.constant(1, dtype=t.dtype))
-        return t
+        return pm.floatX(t)

     def _setup(self, **kwargs):
         pass
