Skip to content

Commit 61c6c48

Browse files
ferrine authored and taku-y committed
delete DataSampler
1 parent 3c9384a commit 61c6c48

File tree

3 files changed

+39
-69
lines changed

3 files changed

+39
-69
lines changed

pymc3/data.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
__all__ = [
1313
'get_data',
1414
'GeneratorAdapter',
15-
'DataSampler',
1615
'Minibatch'
1716
]
1817

@@ -92,64 +91,6 @@ def __hash__(self):
9291
return hash(id(self))
9392

9493

95-
class DataSampler(object):
96-
"""
97-
Convenient picklable data sampler for minibatch inference.
98-
99-
This generator can be used for passing to pm.generator
100-
creating picklable theano computational graph
101-
102-
Parameters
103-
----------
104-
data : array like
105-
batchsize : sample size over zero axis
106-
random_seed : int for numpy random generator
107-
dtype : str representing dtype
108-
109-
Usage
110-
-----
111-
>>> import pickle
112-
>>> from functools import partial
113-
>>> np.random.seed(42) # reproducibility
114-
>>> pm.set_tt_rng(42)
115-
>>> data = np.random.normal(size=(1000,)) + 10
116-
>>> minibatches = DataSampler(data, batchsize=50)
117-
>>> with pm.Model():
118-
... sd = pm.Uniform('sd', 0, 10)
119-
... mu = pm.Normal('mu', sd=10)
120-
... obs_norm = pm.Normal('obs_norm', mu=mu, sd=sd,
121-
... observed=minibatches,
122-
... total_size=data.shape[0])
123-
... adam = partial(pm.adam, learning_rate=.8) # easy problem
124-
... approx = pm.fit(10000, method='advi', obj_optimizer=adam)
125-
>>> new = pickle.loads(pickle.dumps(approx))
126-
>>> new #doctest: +ELLIPSIS
127-
<pymc3.variational.approximations.MeanField object at 0x...>
128-
>>> new.sample(draws=1000)['mu'].mean()
129-
10.08339999101371
130-
>>> new.sample(draws=1000)['sd'].mean()
131-
1.2178044136104513
132-
"""
133-
def __init__(self, data, batchsize=50, random_seed=42, dtype='floatX'):
134-
self.dtype = theano.config.floatX if dtype == 'floatX' else dtype
135-
self.rng = np.random.RandomState(random_seed)
136-
self.data = data
137-
self.n = batchsize
138-
139-
def __iter__(self):
140-
return self
141-
142-
def __next__(self):
143-
idx = (self.rng
144-
.uniform(size=self.n,
145-
low=0.0,
146-
high=self.data.shape[0] - 1e-16)
147-
.astype('int64'))
148-
return np.asarray(self.data[idx], self.dtype)
149-
150-
next = __next__
151-
152-
15394
class Minibatch(tt.TensorVariable):
15495
"""Multidimensional minibatch
15596

pymc3/tests/conftest.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,32 @@
11
import theano
2+
import numpy as np
23
import pytest
34

45

6+
class DataSampler(object):
    """
    Endless, picklable minibatch iterator used as a test helper.

    Not for users: this lives in the test suite only. Every call to
    ``next`` draws ``batchsize`` row indices (independently, so rows may
    repeat) along the zero axis of ``data`` and returns those rows.

    Parameters
    ----------
    data : array like
        Source array; must expose ``shape`` and accept an integer index array.
    batchsize : int
        Number of rows returned per batch.
    random_seed : int
        Seed for the sampler's private ``numpy.random.RandomState``.
    dtype : str
        dtype of the returned batches; ``'floatX'`` resolves to
        ``theano.config.floatX`` at construction time.
    """
    def __init__(self, data, batchsize=50, random_seed=42, dtype='floatX'):
        self.dtype = theano.config.floatX if dtype == 'floatX' else dtype
        self.rng = np.random.RandomState(random_seed)
        self.data = data
        self.n = batchsize

    def __iter__(self):
        return self

    def __next__(self):
        # Draw integer indices directly. The previous
        # ``uniform(0, N - 1e-16).astype('int64')`` relied on an epsilon that
        # float64 rounds away for any N >= 2, and ``uniform`` may return its
        # upper bound through rounding, which would index one past the end.
        # ``randint`` excludes the upper bound by construction.
        idx = self.rng.randint(0, self.data.shape[0], size=self.n)
        return np.asarray(self.data[idx], self.dtype)

    # Python 2 spells the iterator protocol method ``next``.
    next = __next__
28+
29+
530
@pytest.fixture(scope="session", autouse=True)
631
def theano_config():
732
config = theano.configparser.change_flags(compute_test_value='raise')
@@ -16,3 +41,10 @@ def strict_float32():
1641
floatX='float32')
1742
with config:
1843
yield
44+
45+
46+
# ``scope`` is passed by keyword: positional arguments to ``pytest.fixture``
# were deprecated in pytest 5.2 and removed in pytest 6.
@pytest.fixture(scope='session', params=[
    np.random.uniform(size=(1000, 10))
])
def datagen(request):
    """Return a ``DataSampler`` over a random ``(1000, 10)`` array.

    The fixture is session scoped, so every test requesting it shares the
    same sampler instance (and therefore its random state).
    """
    return DataSampler(request.param)

pymc3/tests/test_minibatches.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import theano
99

1010
import pymc3 as pm
11-
from pymc3 import floatX, GeneratorAdapter, generator, DataSampler, tt_rng, Normal
11+
from pymc3 import floatX, GeneratorAdapter, generator, tt_rng, Normal
1212
from pymc3.tests.helpers import select_by_precision
1313
from pymc3.theanof import GeneratorOp
1414

@@ -29,6 +29,7 @@ def integers_ndim(ndim):
2929

3030
@pytest.mark.usefixtures('strict_float32')
3131
class TestGenerator(object):
32+
3233
def test_basic(self):
3334
generator = GeneratorAdapter(integers())
3435
gop = GeneratorOp(generator)()
@@ -86,24 +87,20 @@ def gen():
8687
np.testing.assert_equal(np.ones((10, 10)) * 0, f())
8788
np.testing.assert_equal(np.ones((10, 10)) * 1, f())
8889

89-
def test_pickling(self):
90-
data = np.random.uniform(size=(1000, 10))
91-
minibatches = DataSampler(data, batchsize=50)
92-
gen = generator(minibatches)
90+
def test_pickling(self, datagen):
    """A generator over ``datagen`` pickles; one over ``integers()`` does not."""
    picklable = generator(datagen)
    # Round trip must complete without raising.
    pickle.loads(pickle.dumps(picklable))

    unpicklable = generator(integers())
    with pytest.raises(Exception):
        pickle.dumps(unpicklable)
9796

98-
def test_gen_cloning_with_shape_change(self):
99-
data = floatX(np.random.uniform(size=(1000, 10)))
100-
minibatches = DataSampler(data, batchsize=50)
101-
gen = generator(minibatches)
97+
def test_gen_cloning_with_shape_change(self, datagen):
    """Clone a scan graph, replacing the minibatch generator with the full data.

    The leading dimension of the result follows the input: 50 rows for a
    minibatch from ``datagen``, 1000 rows once the whole array is substituted.
    """
    gen = generator(datagen)
    noise_t = tt_rng().normal(size=gen.shape).T
    product = gen.dot(noise_t)
    row_sums, _ = theano.scan(lambda x: x.sum(), product, n_steps=product.shape[0])
    assert row_sums.eval().shape == (50,)

    # Cast to the generator's dtype so the replacement matches the node it replaces.
    full_data = theano.shared(datagen.data.astype(gen.dtype))
    cloned = theano.clone(row_sums, {gen: full_data**2})
    assert cloned.eval().shape == (1000,)
109106

0 commit comments

Comments
 (0)