@@ -91,7 +91,7 @@ def __hash__(self):
91
91
92
92
93
93
class Minibatch (tt .TensorVariable ):
94
- """Multidimensional minibatch
94
+ """Multidimensional minibatch that is a pure TensorVariable
95
95
96
96
Parameters
97
97
----------
@@ -122,12 +122,16 @@ class Minibatch(tt.TensorVariable):
122
122
Consider we have data
123
123
>>> data = np.random.rand(100, 100)
124
124
125
- if we want 1d slice of size 10
125
+ if we want a 1d slice of size 10 we do
126
126
>>> x = Minibatch(data, batch_size=10)
127
+
128
+ Note that your data is cast to `floatX` if it is not of integer type,
129
+ but you can still pass a `dtype` kwarg to :class:`Minibatch`
127
130
128
131
in case we want 10 sampled rows and columns
129
- [(size, seed), (size, seed)]
130
- >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)])
132
+ [(size, seed), (size, seed)] it is
133
+ >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)], dtype='int32')
134
+ >>> assert str(x.dtype) == 'int32'
131
135
132
136
or simpler with default random seed = 42
133
137
[size, size]
@@ -146,12 +150,21 @@ class Minibatch(tt.TensorVariable):
146
150
>>> with model:
147
151
... approx = pm.fit()
148
152
149
- Notable thing is that Minibatch has `shared`, `minibatch`, attributes
153
+ A notable thing is that :class:`Minibatch` has `shared` and `minibatch` attributes
150
154
you can call later
151
155
>>> x.set_value(np.random.laplace(size=(100, 100)))
152
156
153
157
and minibatches will be then from new storage
154
- it directly affects `x.shared`.
158
+ it directly affects `x.shared`.
159
+ the same thing can be done, but less conveniently:
160
+ >>> x.shared.set_value(pm.floatX(np.random.laplace(size=(100, 100))))
161
+
162
+ a programmatic way to change the storage is as follows;
163
+ I import `partial` for simplicity
164
+ >>> from functools import partial
165
+ >>> datagen = partial(np.random.laplace, size=(100, 100))
166
+ >>> x = Minibatch(datagen(), batch_size=100, update_shared_f=datagen)
167
+ >>> x.update_shared()
155
168
156
169
To be more precise of how we get minibatch, here is a demo
157
170
1) create shared variable
@@ -166,7 +179,7 @@ class Minibatch(tt.TensorVariable):
166
179
That's done. So if you'll need some replacements in the graph
167
180
>>> testdata = pm.floatX(np.random.laplace(size=(1000, 10)))
168
181
169
- you are free to use a kind of this one as `x` as it is Theano Tensor
182
+ you are free to use a kind of this one, as `x` is a regular Theano Tensor
170
183
>>> replacements = {x: testdata}
171
184
>>> node = x ** 2 # arbitrary expressions
172
185
>>> rnode = theano.clone(node, replacements)
@@ -194,8 +207,11 @@ class Minibatch(tt.TensorVariable):
194
207
@theano .configparser .change_flags (compute_test_value = 'raise' )
195
208
def __init__ (self , data , batch_size = 128 , in_memory_size = None ,
196
209
random_seed = 42 , update_shared_f = None ,
197
- broadcastable = None , name = 'Minibatch' ):
198
- data = pm .smartfloatX (np .asarray (data ))
210
+ broadcastable = None , dtype = None , name = 'Minibatch' ):
211
+ if dtype is None :
212
+ data = pm .smartfloatX (np .asarray (data ))
213
+ else :
214
+ data = np .asarray (data , dtype )
199
215
self ._random_seed = random_seed
200
216
in_memory_slc = self ._to_slices (in_memory_size )
201
217
self .batch_size = batch_size
0 commit comments