@@ -144,7 +144,7 @@ class Minibatch(tt.TensorVariable):
>>> with pm.Model() as model:
...     mu = pm.Flat('mu')
...     sd = pm.HalfNormal('sd')
- ...     lik = pm.Normal('lik', mu, sd, observed=x)
+ ...     lik = pm.Normal('lik', mu, sd, observed=x, total_size=(100, 100))

Then you can perform regular Variational Inference out of the box
>>> with model:
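For context, the `total_size=(100, 100)` added above corresponds to a full dataset of shape (100, 100) that is fed in as a 2d minibatch. Below is a minimal sketch of that surrounding workflow; the data array, the `[10, 10]` batch sizes, and the iteration count are illustrative assumptions, not part of this patch.

>>> import numpy as np
>>> import pymc3 as pm
>>> data = np.random.rand(100, 100)              # full dataset, hence total_size=(100, 100)
>>> x = pm.Minibatch(data, batch_size=[10, 10])  # random (10, 10) slices along both dimensions
>>> with pm.Model() as model:
...     mu = pm.Flat('mu')
...     sd = pm.HalfNormal('sd')
...     # total_size rescales the minibatch log-likelihood up to the full dataset size
...     lik = pm.Normal('lik', mu, sd, observed=x, total_size=(100, 100))
...     approx = pm.fit(10000)                   # illustrative; ADVI is the default method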
@@ -188,18 +188,21 @@ class Minibatch(tt.TensorVariable):
For more complex slices some more code is needed that can seem not so clear
They are
>>> moredata = np.random.rand(10, 20, 30, 40, 50)
+
+ The default `total_size` is then `(10, 20, 30, 40, 50)`, but
+ it can be written less verbosely in some cases

- 1) Advanced indexing
+ 1) Advanced indexing, `total_size = (10, Ellipsis, 50)`
>>> x = Minibatch(moredata, [2, Ellipsis, 10])

We take slice only for the first and last dimension
>>> assert x.eval().shape == (2, 20, 30, 40, 10)

- 2) skipping particular dimension
+ 2) Skipping a particular dimension, `total_size = (10, None, 30)`
>>> x = Minibatch(moredata, [2, None, 20])
>>> assert x.eval().shape == (2, 20, 20, 40, 50)

- 3) mixing that all
+ 3) Mixing all of the above, `total_size = (10, None, 30, Ellipsis, 50)`
>>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
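To tie the shorthand `total_size` forms above back to a model, here is a hedged sketch that reuses the advanced-indexing case; the Flat/HalfNormal priors and the fit call mirror the earlier example and are illustrative assumptions rather than part of this patch.

>>> import numpy as np
>>> import pymc3 as pm
>>> moredata = np.random.rand(10, 20, 30, 40, 50)
>>> x = pm.Minibatch(moredata, [2, Ellipsis, 10])  # slice only the first and last dimensions
>>> with pm.Model() as model:
...     mu = pm.Flat('mu')
...     sd = pm.HalfNormal('sd')
...     # total_size uses the matching shorthand: full sizes for the sliced
...     # dimensions, Ellipsis for the dimensions taken as a whole
...     lik = pm.Normal('lik', mu, sd, observed=x, total_size=(10, Ellipsis, 50))
...     approx = pm.fit(1000)                      # illustrative iteration count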