@@ -95,20 +95,26 @@ class Minibatch(tt.TensorVariable):

    Parameters
    ----------
-    data : ndarray
+    data : :class:`ndarray`
        initial data
-    batch_size : int or List[int|tuple(size, random_seed)]
+    batch_size : `int` or `List[int|tuple(size, random_seed)]`
        batch size for inference; a random seed is needed
        for child random generators
-    in_memory_size : int or List[int|slice|Ellipsis]
-        data size for storing in theano.shared
-    random_seed : int
+    dtype : `str`
+        cast data to a specific type
+    broadcastable : tuple[bool]
+        change broadcastable pattern that defaults to `(False, ) * ndim`
+    name : `str`
+        name for tensor, defaults to "Minibatch"
+    random_seed : `int`
        random seed that is used by default
-    update_shared_f : callable
-        returns np.ndarray that will be carefully
+    update_shared_f : `callable`
+        returns :class:`ndarray` that will be carefully
        stored to the underlying shared variable;
        you can use it to change the source of
        minibatches programmatically
+    in_memory_size : `int` or `List[int|slice|Ellipsis]`
+        data size for storing in theano.shared

    Attributes
    ----------
@@ -126,18 +132,18 @@ class Minibatch(tt.TensorVariable):
    >>> x = Minibatch(data, batch_size=10)

    Note that your data is cast to `floatX` if it is not of integer type
-    But you still can add dtype kwarg for :class:`Minibatch`
+    But you can still add a `dtype` kwarg for :class:`Minibatch`

    in case we want 10 sampled rows and columns
-    [(size, seed), (size, seed)] it is
+    it is `[(size, seed), (size, seed)]`
    >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)], dtype='int32')
    >>> assert str(x.dtype) == 'int32'

    or simpler, with default random seed = 42,
-    [size, size]
+    `[size, size]`
    >>> x = Minibatch(data, batch_size=[10, 10])

-    x is a regular TensorVariable that supports any math
+    x is a regular :class:`TensorVariable` that supports any math
    >>> assert x.eval().shape == (10, 10)

    You can pass it to your desired model
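    For example (a minimal sketch; the model below is illustrative and assumes
    the standard `observed`/`total_size` kwargs of PyMC3 distributions):
    >>> with pm.Model():
    ...     mu = pm.Flat('mu')
    ...     sd = pm.HalfNormal('sd')
    ...     lik = pm.Normal('lik', mu, sd, observed=x, total_size=(100, 100))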
@@ -163,7 +169,7 @@ class Minibatch(tt.TensorVariable):
    I import `partial` for simplicity
    >>> from functools import partial
    >>> datagen = partial(np.random.laplace, size=(100, 100))
-    >>> x = Minibatch(datagen(), batch_size=100, update_shared_f=datagen)
+    >>> x = Minibatch(datagen(), batch_size=10, update_shared_f=datagen)
    >>> x.update_shared()

    To be more concrete about how we get a minibatch, here is a demo:
@@ -176,55 +182,55 @@ class Minibatch(tt.TensorVariable):
    3) take that slice
    >>> minibatch = shared[ridx]
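    In plain NumPy the same mechanism looks roughly like this
    (an illustrative analogy, not the actual implementation):
    >>> data_np = np.random.rand(100, 100)
    >>> ridx_np = np.random.randint(0, 100, 10)  # step 2: random row indices
    >>> assert data_np[ridx_np].shape == (10, 100)  # step 3: take that slice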

-    That's done. So if you'll need some replacements in the graph
+    That's done. Next you can use this minibatch somewhere else.
+    You can see that the implementation does not require a fixed shape
+    for the shared variable. Feel free to use that if needed.
+
+    So if you need some replacements in the graph, e.g. to change it to test data
    >>> testdata = pm.floatX(np.random.laplace(size=(1000, 10)))

-    To change minibatch with static data you can create a dict with replacements
+    you can change the minibatch to static data by creating a dict with replacements
    >>> replacements = {x: testdata}
    >>> node = x ** 2  # arbitrary expressions
    >>> rnode = theano.clone(node, replacements)
    >>> assert (testdata ** 2 == rnode.eval()).all()

    To replace the minibatch with its shared variable
-    instead of static `np.array` you should do
+    instead of a static :class:`ndarray` you should do
    >>> replacements = {x.minibatch: x.shared}
    >>> rnode = theano.clone(node, replacements)

    For more complex slices some more code is needed, which can seem less clear
-    They are
    >>> moredata = np.random.rand(10, 20, 30, 40, 50)
-
-    default total_size is then (10, 20, 30, 40, 50) but
-    can be less verbose in some cases
+
+    the default `total_size` that can be passed to a `PyMC3` random node
+    is then `(10, 20, 30, 40, 50)`, but it can be less verbose in some cases

    1) Advanced indexing, `total_size = (10, Ellipsis, 50)`
    >>> x = Minibatch(moredata, [2, Ellipsis, 10])

    We take a slice only for the first and last dimension
    >>> assert x.eval().shape == (2, 20, 30, 40, 10)

-    2) skipping particular dimension, total_size = (10, None, 30)
+    2) Skipping a particular dimension, `total_size = (10, None, 30)`
    >>> x = Minibatch(moredata, [2, None, 20])
    >>> assert x.eval().shape == (2, 20, 20, 40, 50)

-    3) mixing that all, total_size = (10, None, 30, Ellipsis, 50)
+    3) Mixing all of that, `total_size = (10, None, 30, Ellipsis, 50)`
    >>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
    >>> assert x.eval().shape == (2, 20, 20, 40, 10)
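
    The same pattern is then passed to the density (a sketch, assuming the
    `total_size` kwarg of PyMC3 distributions accepts the `None`/`Ellipsis`
    form used above):
    >>> with pm.Model():
    ...     lik = pm.Normal('lik', 0, 1, observed=x,
    ...                     total_size=(10, None, 30, Ellipsis, 50))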
    """
    @theano.configparser.change_flags(compute_test_value='raise')
-    def __init__(self, data, batch_size=128, in_memory_size=None,
-                 random_seed=42, update_shared_f=None,
-                 broadcastable=None, dtype=None, name='Minibatch'):
+    def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',
+                 random_seed=42, update_shared_f=None, in_memory_size=None):
        if dtype is None:
            data = pm.smartfloatX(np.asarray(data))
        else:
            data = np.asarray(data, dtype)
-        self._random_seed = random_seed
-        in_memory_slc = self._to_slices(in_memory_size)
-        self.batch_size = batch_size
+        in_memory_slc = self.make_static_slices(in_memory_size)
        self.shared = theano.shared(data[in_memory_slc])
        self.update_shared_f = update_shared_f
-        self.random_slc = self._to_random_slices(self.shared.shape, batch_size)
+        self.random_slc = self.make_random_slices(self.shared.shape, batch_size, random_seed)
        minibatch = self.shared[self.random_slc]
        if broadcastable is None:
            broadcastable = (False, ) * minibatch.ndim
@@ -249,7 +255,7 @@ def rslice(total, size, seed):
            raise TypeError('Unrecognized size type, %r' % size)
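    # For intuition: rslice(total, size, seed) yields a symbolic vector of
    # `size` random integer indices into range(total); a rough NumPy
    # analogue (illustrative only) would be
    #     np.random.RandomState(seed).randint(0, total, size)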

    @staticmethod
-    def _to_slices(user_size):
+    def make_static_slices(user_size):
        if user_size is None:
            return [Ellipsis]
        elif isinstance(user_size, int):
@@ -271,11 +277,12 @@ def _to_slices(user_size):
        else:
            raise TypeError('Unrecognized size type, %r' % user_size)

-    def _to_random_slices(self, in_memory_shape, batch_size):
+    @classmethod
+    def make_random_slices(cls, in_memory_shape, batch_size, default_random_seed):
        if batch_size is None:
            return [Ellipsis]
        elif isinstance(batch_size, int):
-            slc = [self.rslice(in_memory_shape[0], batch_size, self._random_seed)]
+            slc = [cls.rslice(in_memory_shape[0], batch_size, default_random_seed)]
        elif isinstance(batch_size, (list, tuple)):
            def check(t):
                if t is Ellipsis or t is None:
@@ -297,7 +304,7 @@ def check(t):
                    'size and random seed are both ints, got %r' %
                    batch_size)
            batch_size = [
-                (i, self._random_seed) if isinstance(i, int) else i
+                (i, default_random_seed) if isinstance(i, int) else i
                for i in batch_size
            ]
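            # e.g. batch_size=[2, (20, 7), None] is normalized here to
            # [(2, 42), (20, 7), None]: bare ints get the default seed
            # (42, unless another random_seed was passed to __init__)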
303
310
shape = in_memory_shape
@@ -326,10 +333,10 @@ def check(t):
            else:
                shp_end = np.asarray([])
            shp_begin = shape[:len(begin)]
-            slc_begin = [self.rslice(shp_begin[i], t[0], t[1])
+            slc_begin = [cls.rslice(shp_begin[i], t[0], t[1])
                         if t is not None else tt.arange(shp_begin[i])
                         for i, t in enumerate(begin)]
-            slc_end = [self.rslice(shp_end[i], t[0], t[1])
+            slc_end = [cls.rslice(shp_end[i], t[0], t[1])
                       if t is not None else tt.arange(shp_end[i])
                       for i, t in enumerate(end)]
            slc = slc_begin + mid + slc_end
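            # per dimension: a (size, seed) pair becomes a random slice via
            # rslice, while None keeps the whole axis (tt.arange over its length)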
@@ -339,6 +346,8 @@ def check(t):
        return pm.theanof.ix_(*slc)

    def update_shared(self):
+        if self.update_shared_f is None:
+            raise NotImplementedError("No `update_shared_f` was provided to `__init__`")
        self.set_value(np.asarray(self.update_shared_f(), self.dtype))
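        # with the guard above, e.g. Minibatch(data, batch_size=10).update_shared()
        # now raises NotImplementedError instead of
        # "TypeError: 'NoneType' object is not callable"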

    def set_value(self, value):