Skip to content

Commit 3c9384a

Browse files
ferrinetaku-y
authored andcommitted
Much better api
1 parent e443bac commit 3c9384a

File tree

2 files changed

+37
-20
lines changed

2 files changed

+37
-20
lines changed

pymc3/data.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import pymc3 as pm
8+
from pymc3 import theanof
89
import theano.tensor as tt
910
import theano
1011

@@ -149,7 +150,7 @@ def __next__(self):
149150
next = __next__
150151

151152

152-
class Minibatch(object):
153+
class Minibatch(tt.TensorVariable):
153154
"""Multidimensional minibatch
154155
155156
Parameters
@@ -173,18 +174,31 @@ class Minibatch(object):
173174
minibatch : minibatch tensor
174175
Used for training
175176
"""
176-
def __init__(self, data, batch_size=128, in_memory_size=None, random_seed=42, update_shared_f=None):
177+
@theanof.change_flags(compute_test_value='raise')
178+
def __init__(self, data, batch_size=128, in_memory_size=None,
179+
random_seed=42, update_shared_f=None,
180+
broadcastable=None, name='Minibatch'):
177181
data = pm.smartfloatX(np.asarray(data))
178182
self._random_seed = random_seed
179183
in_memory_slc = self._to_slices(in_memory_size)
180-
self.data = data
181184
self.batch_size = batch_size
182185
self.shared = theano.shared(data[in_memory_slc])
183186
self.update_shared_f = update_shared_f
184187
self.random_slc = self._to_random_slices(self.shared.shape, batch_size)
185-
self.minibatch = self.shared[self.random_slc]
188+
minibatch = self.shared[self.random_slc]
189+
if broadcastable is None:
190+
broadcastable = (False, ) * minibatch.ndim
191+
minibatch = tt.patternbroadcast(minibatch, broadcastable)
192+
self.minibatch = minibatch
193+
super(Minibatch, self).__init__(
194+
self.minibatch.type, None, None, name=name)
195+
theano.Apply(
196+
theano.compile.view_op,
197+
inputs=[self.minibatch], outputs=[self])
198+
self.tag.test_value = copy(self.minibatch.tag.test_value)
186199

187-
def rslice(self, total, size, seed):
200+
@staticmethod
201+
def rslice(total, size, seed):
188202
if size is None:
189203
return slice(None)
190204
elif isinstance(size, int):
@@ -284,13 +298,8 @@ def update_shared(self):
284298
def set_value(self, value):
285299
self.shared.set_value(np.asarray(value, self.dtype))
286300

287-
@property
288-
def dtype(self):
289-
return self.shared.dtype
290-
291-
@property
292-
def type(self):
293-
return self.shared.type
294-
295-
def __repr__(self):
296-
return '<Minibatch of %s>' % self.batch_size
301+
def clone(self):
302+
ret = self.type()
303+
ret.name = self.name
304+
ret.tag = copy(self.tag)
305+
return ret

pymc3/tests/test_minibatches.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,22 +246,30 @@ class TestMinibatch(object):
246246

247247
def test_1d(self):
248248
mb = pm.Minibatch(self.data, 20)
249-
assert mb.minibatch.eval().shape == (20, 10, 40, 10, 50)
249+
assert mb.eval().shape == (20, 10, 40, 10, 50)
250250

251251
def test_2d(self):
252252
with pytest.raises(TypeError):
253253
pm.Minibatch(self.data, (10, 5))
254254
mb = pm.Minibatch(self.data, [(10, 42), (4, 42)])
255-
assert mb.minibatch.eval().shape == (10, 4, 40, 10, 50)
255+
assert mb.eval().shape == (10, 4, 40, 10, 50)
256256

257257
def test_special1(self):
258258
mb = pm.Minibatch(self.data, [(10, 42), None, (4, 42)])
259-
assert mb.minibatch.eval().shape == (10, 10, 4, 10, 50)
259+
assert mb.eval().shape == (10, 10, 4, 10, 50)
260260

261261
def test_special2(self):
262262
mb = pm.Minibatch(self.data, [(10, 42), Ellipsis, (4, 42)])
263-
assert mb.minibatch.eval().shape == (10, 10, 40, 10, 4)
263+
assert mb.eval().shape == (10, 10, 40, 10, 4)
264264

265265
def test_special3(self):
266266
mb = pm.Minibatch(self.data, [(10, 42), None, Ellipsis, (4, 42)])
267-
assert mb.minibatch.eval().shape == (10, 10, 40, 10, 4)
267+
assert mb.eval().shape == (10, 10, 40, 10, 4)
268+
269+
def test_cloning_available(self):
270+
gop = pm.Minibatch(np.arange(100), 1)
271+
res = gop ** 2
272+
shared = theano.shared(np.array([10]))
273+
res1 = theano.clone(res, {gop: shared})
274+
f = theano.function([], res1)
275+
assert f() == np.array([100])

0 commit comments

Comments
 (0)