
Commit 630e41e

ferrine authored and taku-y committed
add docstring
1 parent 2020776 commit 630e41e

File tree

2 files changed: +96 -9 lines changed

pymc3/data.py

Lines changed: 92 additions & 9 deletions
@@ -97,22 +97,99 @@ class Minibatch(tt.TensorVariable):
     ----------
     data : ndarray
         initial data
-    batch_size : int or List[tuple(size, random_seed)]
+    batch_size : int or List[int|tuple(size, random_seed)]
         batch size for inference, random seed is needed
         for child random generators
-    in_memory_size : int or List[int,slice,Ellipsis]
+    in_memory_size : int or List[int|slice|Ellipsis]
         data size for storing in theano.shared
     random_seed : int
-        random seed that is used for 1d random slice
+        random seed that is used by default
     update_shared_f : callable
-        gets in_memory_shape and returns np.ndarray
+        returns np.ndarray that will be carefully
+        stored to the underlying shared variable;
+        you can use it to change the source of
+        minibatches programmatically
 
     Attributes
     ----------
     shared : shared tensor
         Used for storing data
     minibatch : minibatch tensor
         Used for training
+
+    Examples
+    --------
+    Suppose we have some data
+    >>> data = np.random.rand(100, 100)
+
+    If we want a 1d slice of size 10
+    >>> x = Minibatch(data, batch_size=10)
+
+    If we want 10 sampled rows and columns,
+    [(size, seed), (size, seed)]
+    >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)])
+
+    or, more simply, with the default random seed = 42,
+    [size, size]
+    >>> x = Minibatch(data, batch_size=[10, 10])
+
+    `x` is a regular TensorVariable that supports any math
+    >>> assert x.eval().shape == (10, 10)
+
+    You can pass it to your desired model
+    >>> with pm.Model() as model:
+    ...     mu = pm.Flat('mu')
+    ...     sd = pm.HalfNormal('sd')
+    ...     lik = pm.Normal('lik', mu, sd, observed=x)
+
+    Then you can perform regular Variational Inference out of the box
+    >>> with model:
+    ...     approx = pm.fit()
+
+    Note that Minibatch has `shared` and `minibatch` attributes
+    that you can use later
+    >>> x.set_value(np.random.laplace(size=(100, 100)))
+
+    and minibatches will then come from the new storage;
+    this directly affects `x.shared`.
+
+    To be more precise about how we get a minibatch, here is a demo:
+    1) create a shared variable
+    >>> shared = theano.shared(data)
+
+    2) create a random slice of size 10
+    >>> ridx = pm.tt_rng().uniform(size=(10,), low=0, high=data.shape[0]-1e-10).astype('int64')
+
+    3) take that slice
+    >>> minibatch = shared[ridx]
+
+    That's it. If you need replacements in the graph, e.g.
+    >>> testdata = pm.floatX(np.random.laplace(size=(1000, 10)))
+
+    you are free to use `x` like any other Theano tensor
+    >>> replacements = {x: testdata}
+    >>> node = x ** 2  # arbitrary expressions
+    >>> rnode = theano.clone(node, replacements)
+    >>> assert (testdata ** 2 == rnode.eval()).all()
+
+    For more complex slices, some more code is needed;
+    the patterns are the following:
+    >>> moredata = np.random.rand(10, 20, 30, 40, 50)
+
+    1) advanced indexing
+    >>> x = Minibatch(moredata, [2, Ellipsis, 10])
+
+    We take a slice only for the first and last dimensions
+    >>> assert x.eval().shape == (2, 20, 30, 40, 10)
+
+    2) skipping a particular dimension
+    >>> x = Minibatch(moredata, [2, None, 20])
+    >>> assert x.eval().shape == (2, 20, 20, 40, 50)
+
+    3) mixing all of the above
+    >>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
+    >>> assert x.eval().shape == (2, 20, 20, 40, 10)
+
     """
     @theano.configparser.change_flags(compute_test_value='raise')
     def __init__(self, data, batch_size=128, in_memory_size=None,
@@ -181,19 +258,25 @@ def check(t):
                 if t is Ellipsis or t is None:
                     return True
                 else:
-                    if not isinstance(t, (tuple, list)):
-                        return False
-                    else:
+                    if isinstance(t, (tuple, list)):
                         if not len(t) == 2:
                             return False
                         else:
                             return isinstance(t[0], int) and isinstance(t[1], int)
-
+                    elif isinstance(t, int):
+                        return True
+                    else:
+                        return False
+            # end check definition
             if not all(check(t) for t in batch_size):
                 raise TypeError('Unrecognized `batch_size` type, expected '
-                                'int or List[tuple(size, random_seed)] where '
+                                'int or List[int|tuple(size, random_seed)] where '
                                 'size and random seed are both ints, got %r' %
                                 batch_size)
+            batch_size = [
+                (i, self._random_seed) if isinstance(i, int) else i
+                for i in batch_size
+            ]
             shape = in_memory_shape
             if Ellipsis in batch_size:
                 sep = batch_size.index(Ellipsis)
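To see what this change enables in practice, here is a minimal usage sketch based on the docstring above. It assumes PyMC3 3.x with the Theano backend and is illustrative only; it is not part of the commit.

import numpy as np
import pymc3 as pm

data = np.random.rand(100, 100)

# Old style: every sampled dimension needs an explicit (size, random_seed) tuple.
mb_explicit = pm.Minibatch(data, batch_size=[(10, 42), (10, 42)])

# New in this commit: plain ints are accepted per dimension and are paired
# with the Minibatch's default random_seed internally.
mb_plain = pm.Minibatch(data, batch_size=[10, 10])

assert mb_plain.eval().shape == (10, 10)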

pymc3/tests/test_minibatches.py

Lines changed: 4 additions & 0 deletions
@@ -263,6 +263,10 @@ def test_special3(self):
         mb = pm.Minibatch(self.data, [(10, 42), None, Ellipsis, (4, 42)])
         assert mb.eval().shape == (10, 10, 40, 10, 4)
 
+    def test_special4(self):
+        mb = pm.Minibatch(self.data, [10, None, Ellipsis, (4, 42)])
+        assert mb.eval().shape == (10, 10, 40, 10, 4)
+
     def test_cloning_available(self):
         gop = pm.Minibatch(np.arange(100), 1)
         res = gop ** 2
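The new test mixes the plain-int and tuple forms in one `batch_size` list. Below is a rough standalone sketch of the validation and normalization the commit adds; `DEFAULT_SEED`, `check`, and `normalize` are illustrative names here, not the class internals.

DEFAULT_SEED = 42

def check(t):
    # Ellipsis/None pass through, (size, seed) pairs must be two ints,
    # and plain ints are now accepted as well.
    if t is Ellipsis or t is None:
        return True
    if isinstance(t, (tuple, list)):
        return len(t) == 2 and isinstance(t[0], int) and isinstance(t[1], int)
    return isinstance(t, int)

def normalize(batch_size):
    if not all(check(t) for t in batch_size):
        raise TypeError('Unrecognized `batch_size` type')
    # Plain ints get paired with the default seed, mirroring the
    # `(i, self._random_seed) if isinstance(i, int) else i` comprehension.
    return [(i, DEFAULT_SEED) if isinstance(i, int) else i for i in batch_size]

# The argument from test_special4 normalizes to the one from test_special3.
assert normalize([10, None, Ellipsis, (4, 42)]) == [(10, 42), None, Ellipsis, (4, 42)]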

0 commit comments