Commit 5b2d8c5

ferrine authored and taku-y committed

more decorations

1 parent 699344b commit 5b2d8c5

File tree

1 file changed: +44 -35 lines changed

pymc3/data.py

Lines changed: 44 additions & 35 deletions
@@ -95,20 +95,26 @@ class Minibatch(tt.TensorVariable):
 
     Parameters
     ----------
-    data : ndarray
+    data : :class:`ndarray`
         initial data
-    batch_size : int or List[int|tuple(size, random_seed)]
+    batch_size : `int` or `List[int|tuple(size, random_seed)]`
         batch size for inference; a random seed is needed
         for child random generators
-    in_memory_size : int or List[int|slice|Ellipsis]
-        data size for storing in theano.shared
-    random_seed : int
+    dtype : `str`
+        cast data to a specific type
+    broadcastable : `tuple[bool]`
+        change the broadcastable pattern, which defaults to `(False, ) * ndim`
+    name : `str`
+        name for the tensor, defaults to "Minibatch"
+    random_seed : `int`
         random seed that is used by default
-    update_shared_f : callable
-        returns np.ndarray that will be carefully
+    update_shared_f : `callable`
+        returns :class:`ndarray` that will be carefully
         stored to the underlying shared variable;
         you can use it to change the source of
         minibatches programmatically
+    in_memory_size : `int` or `List[int|slice|Ellipsis]`
+        data size for storing in theano.shared
 
     Attributes
     ----------
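As a quick sketch of the keyword arguments documented above (the array shape and tensor name here are made up for illustration):

    import numpy as np
    import pymc3 as pm

    data = np.random.randint(0, 10, size=(100, 50))
    # cast to int32 and give the tensor a custom name;
    # broadcastable keeps its default (False,) * ndim pattern
    x = pm.Minibatch(data, batch_size=10, dtype='int32', name='x_minibatch')
    assert str(x.dtype) == 'int32'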
@@ -126,18 +132,18 @@ class Minibatch(tt.TensorVariable):
     >>> x = Minibatch(data, batch_size=10)
 
     Note that your data is cast to `floatX` if it is not of integer type,
-    but you still can add dtype kwarg for :class:`Minibatch`
+    but you still can add the `dtype` kwarg for :class:`Minibatch`
 
     in case we want 10 sampled rows and columns,
-    [(size, seed), (size, seed)] it is
+    with `[(size, seed), (size, seed)]` it is
    >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)], dtype='int32')
    >>> assert str(x.dtype) == 'int32'
 
     or, simpler, with the default random seed = 42,
-    [size, size]
+    `[size, size]`
    >>> x = Minibatch(data, batch_size=[10, 10])
 
-    x is a regular TensorVariable that supports any math
+    `x` is a regular :class:`TensorVariable` that supports any math
    >>> assert x.eval().shape == (10, 10)
 
     You can pass it to your desired model
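For example, a minimal sketch of feeding the minibatch to an observed node; the model itself is an assumption for illustration, but `total_size` is how PyMC3 rescales the minibatch log-likelihood to the full data size:

    import numpy as np
    import pymc3 as pm

    data = np.random.rand(100, 100)
    x = pm.Minibatch(data, batch_size=[10, 10])
    with pm.Model():
        mu = pm.Normal('mu', mu=0, sd=1)
        # total_size matches the full data shape, not the batch shape
        pm.Normal('obs', mu=mu, sd=1, observed=x, total_size=(100, 100))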
@@ -163,7 +169,7 @@ class Minibatch(tt.TensorVariable):
     I import `partial` for simplicity
    >>> from functools import partial
    >>> datagen = partial(np.random.laplace, size=(100, 100))
-    >>> x = Minibatch(datagen(), batch_size=100, update_shared_f=datagen)
+    >>> x = Minibatch(datagen(), batch_size=10, update_shared_f=datagen)
    >>> x.update_shared()
 
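In practice `update_shared_f` lets you swap a fresh chunk of data in between steps; a sketch extending the doctest above (the epoch loop is hypothetical, not part of this commit):

    import numpy as np
    from functools import partial
    import pymc3 as pm

    datagen = partial(np.random.laplace, size=(100, 100))
    x = pm.Minibatch(datagen(), batch_size=10, update_shared_f=datagen)
    for epoch in range(5):
        # loads a new (100, 100) array into the underlying shared variable,
        # so subsequent evaluations of x draw minibatches from fresh data
        x.update_shared()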
     To be more concrete about how we get a minibatch, here is a demo
@@ -176,55 +182,55 @@ class Minibatch(tt.TensorVariable):
     3) take that slice
    >>> minibatch = shared[ridx]
 
-    That's done. So if you'll need some replacements in the graph
+    That's done. Next you can use this minibatch somewhere else.
+    You can see that in this implementation the minibatch does not require
+    a fixed shape for the shared variable. Feel free to use that if needed.
+
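A runnable version of that demo; steps 1) and 2) fall outside this hunk, so the shared variable and the random index tensor are reconstructed here as assumptions mirroring the surrounding docstring:

    import numpy as np
    import theano
    import pymc3 as pm

    data = pm.floatX(np.random.laplace(size=(1000, 10)))
    # 1) create a shared variable holding the data
    shared = theano.shared(data)
    # 2) draw 10 random row indices on every evaluation
    ridx = pm.tt_rng().uniform(size=(10,), low=0,
                               high=data.shape[0] - 1e-10).astype('int64')
    # 3) take that slice
    minibatch = shared[ridx]
    assert minibatch.eval().shape == (10, 10)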
+    So if you need some replacements in the graph, e.g. to change it to test data
    >>> testdata = pm.floatX(np.random.laplace(size=(1000, 10)))
 
-    To change minibatch with static data you can create a dict with replacements
+    To change the minibatch to static data you can create a dict with replacements
    >>> replacements = {x: testdata}
    >>> node = x ** 2  # arbitrary expressions
    >>> rnode = theano.clone(node, replacements)
    >>> assert (testdata ** 2 == rnode.eval()).all()
 
     To replace the minibatch with its shared variable
-    instead of static `np.array` you should do
+    instead of a static :class:`ndarray` you should do
    >>> replacements = {x.minibatch: x.shared}
    >>> rnode = theano.clone(node, replacements)
 
     For more complex slices some more code is needed, which can seem not so clear
-    They are
    >>> moredata = np.random.rand(10, 20, 30, 40, 50)
-
-    default total_size is then (10, 20, 30, 40, 50) but
-    can be less verbose in some cases
+
+    The default `total_size` that can be passed to a `PyMC3` random node
+    is then `(10, 20, 30, 40, 50)`, but it can be less verbose in some cases
 
     1) Advanced indexing, `total_size = (10, Ellipsis, 50)`
    >>> x = Minibatch(moredata, [2, Ellipsis, 10])
 
     We take a slice only for the first and last dimensions
    >>> assert x.eval().shape == (2, 20, 30, 40, 10)
 
-    2) skipping particular dimension, total_size = (10, None, 30)
+    2) Skipping a particular dimension, `total_size = (10, None, 30)`
    >>> x = Minibatch(moredata, [2, None, 20])
    >>> assert x.eval().shape == (2, 20, 20, 40, 50)
 
-    3) mixing that all, total_size = (10, None, 30, Ellipsis, 50)
+    3) Mixing it all, `total_size = (10, None, 30, Ellipsis, 50)`
    >>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
    >>> assert x.eval().shape == (2, 20, 20, 40, 10)
     """
     @theano.configparser.change_flags(compute_test_value='raise')
-    def __init__(self, data, batch_size=128, in_memory_size=None,
-                 random_seed=42, update_shared_f=None,
-                 broadcastable=None, dtype=None, name='Minibatch'):
+    def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',
+                 random_seed=42, update_shared_f=None, in_memory_size=None):
         if dtype is None:
             data = pm.smartfloatX(np.asarray(data))
         else:
             data = np.asarray(data, dtype)
-        self._random_seed = random_seed
-        in_memory_slc = self._to_slices(in_memory_size)
-        self.batch_size = batch_size
+        in_memory_slc = self.make_static_slices(in_memory_size)
         self.shared = theano.shared(data[in_memory_slc])
         self.update_shared_f = update_shared_f
-        self.random_slc = self._to_random_slices(self.shared.shape, batch_size)
+        self.random_slc = self.make_random_slices(self.shared.shape, batch_size, random_seed)
         minibatch = self.shared[self.random_slc]
         if broadcastable is None:
             broadcastable = (False, ) * minibatch.ndim
@@ -249,7 +255,7 @@ def rslice(total, size, seed):
             raise TypeError('Unrecognized size type, %r' % size)
 
     @staticmethod
-    def _to_slices(user_size):
+    def make_static_slices(user_size):
         if user_size is None:
             return [Ellipsis]
         elif isinstance(user_size, int):
@@ -271,11 +277,12 @@ def _to_slices(user_size):
         else:
             raise TypeError('Unrecognized size type, %r' % user_size)
 
-    def _to_random_slices(self, in_memory_shape, batch_size):
+    @classmethod
+    def make_random_slices(cls, in_memory_shape, batch_size, default_random_seed):
         if batch_size is None:
             return [Ellipsis]
         elif isinstance(batch_size, int):
-            slc = [self.rslice(in_memory_shape[0], batch_size, self._random_seed)]
+            slc = [cls.rslice(in_memory_shape[0], batch_size, default_random_seed)]
         elif isinstance(batch_size, (list, tuple)):
             def check(t):
                 if t is Ellipsis or t is None:
@@ -297,7 +304,7 @@ def check(t):
                     'size and random seed are both ints, got %r' %
                     batch_size)
             batch_size = [
-                (i, self._random_seed) if isinstance(i, int) else i
+                (i, default_random_seed) if isinstance(i, int) else i
                 for i in batch_size
             ]
             shape = in_memory_shape
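Per the list comprehension above, plain ints inside `batch_size` are paired with the default seed, while explicit `(size, seed)` tuples pass through unchanged; e.g.:

    batch_size = [10, (5, 123)]
    default_random_seed = 42
    batch_size = [(i, default_random_seed) if isinstance(i, int) else i
                  for i in batch_size]
    assert batch_size == [(10, 42), (5, 123)]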
@@ -326,10 +333,10 @@ def check(t):
             else:
                 shp_end = np.asarray([])
             shp_begin = shape[:len(begin)]
-            slc_begin = [self.rslice(shp_begin[i], t[0], t[1])
+            slc_begin = [cls.rslice(shp_begin[i], t[0], t[1])
                          if t is not None else tt.arange(shp_begin[i])
                          for i, t in enumerate(begin)]
-            slc_end = [self.rslice(shp_end[i], t[0], t[1])
+            slc_end = [cls.rslice(shp_end[i], t[0], t[1])
                        if t is not None else tt.arange(shp_end[i])
                        for i, t in enumerate(end)]
             slc = slc_begin + mid + slc_end
@@ -339,6 +346,8 @@ def check(t):
             return pm.theanof.ix_(*slc)
 
     def update_shared(self):
+        if self.update_shared_f is None:
+            raise NotImplementedError("No `update_shared_f` was provided to `__init__`")
         self.set_value(np.asarray(self.update_shared_f(), self.dtype))
 
     def set_value(self, value):
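With the guard added above, a `Minibatch` created without `update_shared_f` now fails fast with a clear message instead of attempting to call `None`; a sketch assuming the commit applied:

    import numpy as np
    import pymc3 as pm

    x = pm.Minibatch(np.random.rand(100, 10), batch_size=10)  # no update_shared_f
    try:
        x.update_shared()
    except NotImplementedError as e:
        print(e)  # No `update_shared_f` was provided to `__init__`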
