Commit 7180472

ferrine authored and taku-y committed
typos
1 parent 630e41e commit 7180472

File tree

1 file changed: +25 -9 lines changed


pymc3/data.py

Lines changed: 25 additions & 9 deletions
@@ -91,7 +91,7 @@ def __hash__(self):
 
 
 class Minibatch(tt.TensorVariable):
-    """Multidimensional minibatch
+    """Multidimensional minibatch that is pure TensorVariable
 
     Parameters
     ----------
@@ -122,12 +122,16 @@ class Minibatch(tt.TensorVariable):
     Consider we have data
     >>> data = np.random.rand(100, 100)
 
-    if we want 1d slice of size 10
+    if we want 1d slice of size 10 we do
     >>> x = Minibatch(data, batch_size=10)
+
+    Note, that your data is casted to `floatX` if it is not integer type
+    But you still can add dtype kwarg for :class:`Minibatch`
 
     in case we want 10 sampled rows and columns
-    [(size, seed), (size, seed)]
-    >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)])
+    [(size, seed), (size, seed)] it is
+    >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)], dtype='int32')
+    >>> assert str(x.dtype) == 'int32'
 
     or simpler with default random seed = 42
     [size, size]
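The per-dimension `(size, seed)` batch spec described in the hunk above can be sketched in plain NumPy. The helper name `random_slices` is hypothetical (not part of pymc3); it only illustrates the idea that each axis draws `size` indices with its own seed:

```python
import numpy as np

def random_slices(data, batch_size):
    # Sketch: for each axis, a (size, seed) pair draws `size` random
    # indices with that seed; a bare int uses the default seed 42.
    idx = []
    for axis, spec in enumerate(batch_size):
        size, seed = spec if isinstance(spec, tuple) else (spec, 42)
        rng = np.random.RandomState(seed)
        idx.append(rng.randint(0, data.shape[axis], size))
    # np.ix_ combines the per-axis index vectors into an open mesh
    return data[np.ix_(*idx)]

data = np.random.rand(100, 100)
mb = random_slices(data, [(10, 42), (10, 42)])
```

Because the seeds are fixed, repeated calls with the same spec select the same indices, which is the point of passing explicit seeds per dimension.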
@@ -146,12 +150,21 @@ class Minibatch(tt.TensorVariable):
     >>> with model:
     ...     approx = pm.fit()
 
-    Notable thing is that Minibatch has `shared`, `minibatch`, attributes
+    Notable thing is that :class:`Minibatch` has `shared`, `minibatch`, attributes
     you can call later
     >>> x.set_value(np.random.laplace(size=(100, 100)))
 
     and minibatches will be then from new storage
-    it directly affects `x.shared`.
+    it directly affects `x.shared`.
+    the same thing would be but less convenient
+    >>> x.shared.set_value(pm.floatX(np.random.laplace(size=(100, 100))))
+
+    programmatic way to change storage is as following
+    I import `partial` for simplicity
+    >>> from functools import partial
+    >>> datagen = partial(np.random.laplace, size=(100, 100))
+    >>> x = Minibatch(datagen(), batch_size=100, update_shared_f=datagen)
+    >>> x.update_shared()
 
     To be more precise of how we get minibatch, here is a demo
     1) create shared variable
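The `update_shared_f` mechanism documented above can be sketched without theano: a storage object keeps a data generator and refreshes its backing value on demand. `SharedStorage` is a hypothetical stand-in for the theano shared variable behind `x.shared`, not pymc3 code:

```python
from functools import partial
import numpy as np

class SharedStorage:
    # Sketch of the update_shared_f idea: hold current data plus an
    # optional zero-argument generator that produces replacement data.
    def __init__(self, data, update_shared_f=None):
        self.value = data
        self.update_shared_f = update_shared_f

    def set_value(self, value):
        self.value = value

    def update_shared(self):
        # Refresh storage from the generator, mirroring x.update_shared()
        if self.update_shared_f is None:
            raise RuntimeError("no update_shared_f was provided")
        self.set_value(self.update_shared_f())

datagen = partial(np.random.laplace, size=(100, 100))
storage = SharedStorage(datagen(), update_shared_f=datagen)
storage.update_shared()  # pulls a fresh (100, 100) draw into storage
```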
@@ -166,7 +179,7 @@ class Minibatch(tt.TensorVariable):
     That's done. So if you'll need some replacements in the graph
     >>> testdata = pm.floatX(np.random.laplace(size=(1000, 10)))
 
-    you are free to use a kind of this one as `x` as it is Theano Tensor
+    you are free to use a kind of this one as `x` is regular Theano Tensor
     >>> replacements = {x: testdata}
     >>> node = x ** 2 # arbitrary expressions
     >>> rnode = theano.clone(node, replacements)
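The `theano.clone(node, replacements)` contract above (rebuild an expression graph with some inputs swapped, leaving the original untouched) can be sketched with a toy `Node` class. This is an illustrative stand-in, not theano's implementation:

```python
class Node:
    # Tiny expression-graph node: an operation plus its inputs.
    def __init__(self, op, *inputs):
        self.op, self.inputs = op, inputs

    def eval(self):
        vals = [i.eval() if isinstance(i, Node) else i for i in self.inputs]
        return self.op(*vals)

def clone(node, replacements):
    # Return a copy of the graph with nodes substituted per `replacements`;
    # the original graph is not modified (the theano.clone contract).
    if node in replacements:
        return replacements[node]
    if not isinstance(node, Node):
        return node
    return Node(node.op, *(clone(i, replacements) for i in node.inputs))

x = Node(lambda: 3.0)              # stands in for the minibatch tensor
node = Node(lambda v: v ** 2, x)   # arbitrary expression: x ** 2
rnode = clone(node, {x: Node(lambda: 5.0)})  # x replaced by test data
```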
@@ -194,8 +207,11 @@ class Minibatch(tt.TensorVariable):
     @theano.configparser.change_flags(compute_test_value='raise')
     def __init__(self, data, batch_size=128, in_memory_size=None,
                  random_seed=42, update_shared_f=None,
-                 broadcastable=None, name='Minibatch'):
-        data = pm.smartfloatX(np.asarray(data))
+                 broadcastable=None, dtype=None, name='Minibatch'):
+        if dtype is None:
+            data = pm.smartfloatX(np.asarray(data))
+        else:
+            data = np.asarray(data, dtype)
         self._random_seed = random_seed
         in_memory_slc = self._to_slices(in_memory_size)
         self.batch_size = batch_size
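The dtype branch added in `__init__` above can be sketched in plain NumPy. `smart_cast` is a hypothetical helper; it assumes `pm.smartfloatX` casts float data to theano's `floatX` while leaving integer data untouched (here `floatX` is a plain parameter):

```python
import numpy as np

def smart_cast(data, dtype=None, floatX="float32"):
    # Sketch of the commit's logic: with no explicit dtype, float data
    # is cast to floatX and integer data is kept as-is (roughly what
    # pm.smartfloatX does); an explicit dtype takes precedence.
    if dtype is None:
        data = np.asarray(data)
        if data.dtype.kind == "f":
            data = data.astype(floatX)
        return data
    return np.asarray(data, dtype)
```

This matches the new doctest: passing `dtype='int32'` keeps integer minibatches instead of silently promoting everything to `floatX`.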
