Skip to content

Commit e443bac

Browse files
ferrinetaku-y
authored andcommitted
use strict floatX
1 parent 4f4b227 commit e443bac

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pymc3/data.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ class Minibatch(object):
174174
Used for training
175175
"""
176176
def __init__(self, data, batch_size=128, in_memory_size=None, random_seed=42, update_shared_f=None):
177+
data = pm.smartfloatX(np.asarray(data))
177178
self._random_seed = random_seed
178179
in_memory_slc = self._to_slices(in_memory_size)
179180
self.data = data
@@ -278,10 +279,18 @@ def check(t):
278279
return pm.theanof.ix_(*slc)
279280

280281
def update_shared(self):
281-
self.set_value(self.update_shared_f())
282+
self.set_value(np.asarray(self.update_shared_f(), self.dtype))
282283

283284
def set_value(self, value):
284-
self.shared.set_value(value)
285+
self.shared.set_value(np.asarray(value, self.dtype))
286+
287+
@property
288+
def dtype(self):
289+
return self.shared.dtype
290+
291+
@property
292+
def type(self):
293+
return self.shared.type
285294

286295
def __repr__(self):
287296
return '<Minibatch of %s>' % self.batch_size

pymc3/tests/test_minibatches.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616
def integers():
1717
i = 0
1818
while True:
19-
yield floatX(i)
19+
yield pm.floatX(i)
2020
i += 1
2121

2222

2323
def integers_ndim(ndim):
2424
i = 0
2525
while True:
26-
yield floatX(np.ones((2,) * ndim) * i)
26+
yield np.ones((2,) * ndim) * i
2727
i += 1
2828

2929

30+
@pytest.mark.usefixtures('strict_float32')
3031
class TestGenerator(object):
3132
def test_basic(self):
3233
generator = GeneratorAdapter(integers())
@@ -239,6 +240,7 @@ def test_free_rv(self):
239240
[1, 1]]))
240241

241242

243+
@pytest.mark.usefixtures('strict_float32')
242244
class TestMinibatch(object):
243245
data = np.random.rand(30, 10, 40, 10, 50)
244246

0 commit comments

Comments
 (0)