5
5
6
6
import numpy as np
7
7
import pymc3 as pm
8
+ from pymc3 import theanof
8
9
import theano .tensor as tt
9
10
import theano
10
11
@@ -149,7 +150,7 @@ def __next__(self):
149
150
next = __next__
150
151
151
152
152
- class Minibatch (object ):
153
+ class Minibatch (tt . TensorVariable ):
153
154
"""Multidimensional minibatch
154
155
155
156
Parameters
@@ -173,18 +174,31 @@ class Minibatch(object):
173
174
minibatch : minibatch tensor
174
175
Used for training
175
176
"""
176
- def __init__ (self , data , batch_size = 128 , in_memory_size = None , random_seed = 42 , update_shared_f = None ):
177
+ @theanof .change_flags (compute_test_value = 'raise' )
178
+ def __init__ (self , data , batch_size = 128 , in_memory_size = None ,
179
+ random_seed = 42 , update_shared_f = None ,
180
+ broadcastable = None , name = 'Minibatch' ):
177
181
data = pm .smartfloatX (np .asarray (data ))
178
182
self ._random_seed = random_seed
179
183
in_memory_slc = self ._to_slices (in_memory_size )
180
- self .data = data
181
184
self .batch_size = batch_size
182
185
self .shared = theano .shared (data [in_memory_slc ])
183
186
self .update_shared_f = update_shared_f
184
187
self .random_slc = self ._to_random_slices (self .shared .shape , batch_size )
185
- self .minibatch = self .shared [self .random_slc ]
188
+ minibatch = self .shared [self .random_slc ]
189
+ if broadcastable is None :
190
+ broadcastable = (False , ) * minibatch .ndim
191
+ minibatch = tt .patternbroadcast (minibatch , broadcastable )
192
+ self .minibatch = minibatch
193
+ super (Minibatch , self ).__init__ (
194
+ self .minibatch .type , None , None , name = name )
195
+ theano .Apply (
196
+ theano .compile .view_op ,
197
+ inputs = [self .minibatch ], outputs = [self ])
198
+ self .tag .test_value = copy (self .minibatch .tag .test_value )
186
199
187
- def rslice (self , total , size , seed ):
200
+ @staticmethod
201
+ def rslice (total , size , seed ):
188
202
if size is None :
189
203
return slice (None )
190
204
elif isinstance (size , int ):
@@ -284,13 +298,8 @@ def update_shared(self):
284
298
def set_value (self , value ):
285
299
self .shared .set_value (np .asarray (value , self .dtype ))
286
300
287
- @property
288
- def dtype (self ):
289
- return self .shared .dtype
290
-
291
- @property
292
- def type (self ):
293
- return self .shared .type
294
-
295
- def __repr__ (self ):
296
- return '<Minibatch of %s>' % self .batch_size
301
+ def clone (self ):
302
+ ret = self .type ()
303
+ ret .name = self .name
304
+ ret .tag = copy (self .tag )
305
+ return ret
0 commit comments