Skip to content

Commit 1e007fa

Browse files
ferrinetaku-y
authored andcommitted
add total size hints
1 parent 7180472 commit 1e007fa

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

pymc3/data.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class Minibatch(tt.TensorVariable):
144144
>>> with pm.Model() as model:
145145
... mu = pm.Flat('mu')
146146
... sd = pm.HalfNormal('sd')
147-
... lik = pm.Normal('lik', mu, sd, observed=x)
147+
... lik = pm.Normal('lik', mu, sd, observed=x, total_size=(100, 100))
148148
149149
Then you can perform regular Variational Inference out of the box
150150
>>> with model:
@@ -188,18 +188,21 @@ class Minibatch(tt.TensorVariable):
188188
For more complex slices some more code is needed that can seem not so clear
189189
They are
190190
>>> moredata = np.random.rand(10, 20, 30, 40, 50)
191+
192+
default total_size is then (10, 20, 30, 40, 50) but
193+
can be less verbose in sove cases
191194
192-
1) Advanced indexing
195+
1) Advanced indexing, `total_size = (10, Ellipsis, 50)`
193196
>>> x = Minibatch(moredata, [2, Ellipsis, 10])
194197
195198
We take slice only for the first and last dimension
196199
>>> assert x.eval().shape == (2, 20, 30, 40, 10)
197200
198-
2) skipping particular dimension
201+
2) skipping particular dimension, total_size = (10, None, 30)
199202
>>> x = Minibatch(moredata, [2, None, 20])
200203
>>> assert x.eval().shape == (2, 20, 20, 40, 50)
201204
202-
3) mixing that all
205+
3) mixing that all, total_size = (10, None, 30, Ellipsis, 50)
203206
>>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
204207
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
205208

pymc3/tests/test_minibatches.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ def test_common_errors(self):
223223
Normal('n', observed=[[1]], total_size=[Ellipsis, Ellipsis])
224224
assert 'Double Ellipsis' in str(e.value)
225225

226+
def test_mixed1(self):
227+
with pm.Model():
228+
data = np.random.rand(10, 20, 30, 40, 50)
229+
mb = pm.Minibatch(data, [2, None, 20, Ellipsis, 10])
230+
Normal('n', observed=mb, total_size=(10, None, 30, Ellipsis, 50))
231+
232+
def test_mixed2(self):
233+
with pm.Model():
234+
data = np.random.rand(10, 20, 30, 40, 50)
235+
mb = pm.Minibatch(data, [2, None, 20])
236+
Normal('n', observed=mb, total_size=(10, None, 30))
237+
226238
def test_free_rv(self):
227239
with pm.Model() as model4:
228240
Normal('n', observed=[[1, 1],

0 commit comments

Comments
 (0)