 import pandas as pd
 
 from aesara.compile import SharedVariable
-from aesara.graph.basic import Apply, Constant, clone, graph_inputs
+from aesara.graph.basic import clone_replace, graph_inputs
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import Op
 from aesara.graph.opt import MergeOptimizer
 from aesara.link.jax.dispatch import jax_funcify
-from aesara.tensor.type import TensorType
 
 from pymc import modelcontext
 from pymc.aesaraf import compile_rv_inplace
 
 warnings.warn("This module is experimental.")
 
 
-class NumPyroNUTS(Op):
-    def __init__(
-        self,
-        inputs,
-        outputs,
-        target_accept=0.8,
-        draws=1000,
-        tune=1000,
-        chains=4,
-        seed=None,
-        progress_bar=True,
-    ):
-        self.draws = draws
-        self.tune = tune
-        self.chains = chains
-        self.target_accept = target_accept
-        self.progress_bar = progress_bar
-        self.seed = seed
+def replace_shared_variables(graph):
+    """Replace shared variables in graph by their constant values
 
-        self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
-        self.inputs_type = tuple(input.type for input in inputs)
-        self.outputs_type = tuple(output.type for output in outputs)
-        self.nin = len(inputs)
-        self.nout = len(outputs)
-        self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
-        self.samples_bcast = [self.chains == 1, self.draws == 1]
+    Raises
+    ------
+    ValueError
+        If any shared variable contains default_updates
+    """
 
-        self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
-        MergeOptimizer().optimize(self.fgraph)
+    shared_variables = [var for var in graph_inputs(graph) if isinstance(var, SharedVariable)]
 
-        super().__init__()
-
-    def make_node(self, *inputs):
-
-        # The samples for each variable
-        outputs = [
-            TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
-        ]
-
-        # The leapfrog statistics
-        outputs += [TensorType("int64", self.samples_bcast)()]
-
-        all_inputs = list(inputs)
-        if self.nshared > 0:
-            all_inputs += self.inputs[-self.nshared:]
-
-        return Apply(self, all_inputs, outputs)
-
-    def do_constant_folding(self, *args):
-        return False
-
-    def perform(self, node, inputs, outputs):
-        raise NotImplementedError()
-
-
-@jax_funcify.register(NumPyroNUTS)
-def jax_funcify_NumPyroNUTS(op, node, **kwargs):
-    from numpyro.infer import MCMC, NUTS
-
-    draws = op.draws
-    tune = op.tune
-    chains = op.chains
-    target_accept = op.target_accept
-    progress_bar = op.progress_bar
-    seed = op.seed
-
-    # Compile the "inner" log-likelihood function. This will have extra shared
-    # variable inputs as the last arguments
-    logp_fn = jax_funcify(op.fgraph, **kwargs)
-
-    if isinstance(logp_fn, (list, tuple)):
-        # This handles the new JAX backend, which always returns a tuple
-        logp_fn = logp_fn[0]
-
-    def _sample(*inputs):
-
-        if op.nshared > 0:
-            current_state = inputs[: -op.nshared]
-            shared_inputs = tuple(op.fgraph.inputs[-op.nshared:])
-        else:
-            current_state = inputs
-            shared_inputs = ()
-
-        def log_fn_wrap(x):
-            res = logp_fn(
-                *(
-                    x
-                    # We manually obtain the shared values and added them
-                    # as arguments to our compiled "inner" function
-                    + tuple(
-                        v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
-                    )
-                )
-            )
-
-            if isinstance(res, (list, tuple)):
-                # This handles the new JAX backend, which always returns a tuple
-                res = res[0]
-
-            return -res
-
-        nuts_kernel = NUTS(
-            potential_fn=log_fn_wrap,
-            target_accept_prob=target_accept,
-            adapt_step_size=True,
-            adapt_mass_matrix=True,
-            dense_mass=False,
-        )
-
-        pmap_numpyro = MCMC(
-            nuts_kernel,
-            num_warmup=tune,
-            num_samples=draws,
-            num_chains=chains,
-            postprocess_fn=None,
-            chain_method="parallel",
-            progress_bar=progress_bar,
+    if any(hasattr(var, "default_update") for var in shared_variables):
+        raise ValueError(
+            "Graph contains shared variables with default_update which cannot "
+            "be safely replaced."
         )
 
-        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
-        samples = pmap_numpyro.get_samples(group_by_chain=True)
-        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
-        return tuple(samples) + (leapfrogs_taken,)
+    replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
 
-    return _sample
+    new_graph = clone_replace(graph, replace=replacements)
+    return new_graph
 
 
 def sample_numpyro_nuts(
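
An editorial aside on the new helper (not part of the diff): replace_shared_variables walks the graph inputs, snapshots each SharedVariable with at.constant, and returns a clone that no longer depends on mutable state. A minimal sketch, assuming the usual Aesara imports and the hypothetical variables a, x, and y:

import aesara.tensor as at
from aesara import shared

a = shared(2.0, name="a")  # mutable shared variable
x = at.scalar("x")

# The cloned graph computes x * constant(2.0)
(y,) = replace_shared_variables([x * a])

# Subsequent a.set_value(...) calls no longer affect `y`
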
@@ -165,72 +59,101 @@ def sample_numpyro_nuts(
     progress_bar=True,
     keep_untransformed=False,
 ):
+    from numpyro.infer import MCMC, NUTS
+
     model = modelcontext(model)
 
-    seed = jax.random.PRNGKey(random_seed)
+    tic1 = pd.Timestamp.now()
+    print("Compiling...", file=sys.stdout)
 
     rv_names = [rv.name for rv in model.value_vars]
     init_state = [model.initial_point[rv_name] for rv_name in rv_names]
     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
 
-    nuts_inputs = sorted(
-        (v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)),
-        key=lambda x: isinstance(x, SharedVariable),
-    )
-    map_seed = jax.random.split(seed, chains)
-    numpyro_samples = NumPyroNUTS(
-        nuts_inputs,
-        [model.logpt],
-        target_accept=target_accept,
-        draws=draws,
-        tune=tune,
-        chains=chains,
-        seed=map_seed,
-        progress_bar=progress_bar,
-    )(*init_state_batched_at)
+    logpt = replace_shared_variables([model.logpt])[0]
+    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
+    MergeOptimizer().optimize(logpt_fgraph)
+    logp_fn = jax_funcify(logpt_fgraph)
 
-    # Un-transform the transformed variables in JAX
-    sample_outputs = []
-    for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
-        rv = model.values_to_rvs[value_var]
-        transform = getattr(value_var.tag, "transform", None)
-        if transform is not None:
-            untrans_value_var = transform.backward(rv, rv_samples)
-            untrans_value_var.name = rv.name
-            sample_outputs.append(untrans_value_var)
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]
 
-            if keep_untransformed:
-                rv_samples.name = value_var.name
-                sample_outputs.append(rv_samples)
-        else:
-            rv_samples.name = rv.name
-            sample_outputs.append(rv_samples)
+    def logp_fn_wrap(x):
+        res = logp_fn(*x)
 
-    print("Compiling...", file=sys.stdout)
+        if isinstance(res, (list, tuple)):
+            # This handles the new JAX backend, which always returns a tuple
+            res = res[0]
 
-    tic1 = pd.Timestamp.now()
-    _sample = compile_rv_inplace(
-        [],
-        sample_outputs + [numpyro_samples[-1]],
-        allow_input_downcast=True,
-        on_unused_input="ignore",
-        accept_inplace=True,
-        mode="JAX",
+        # Jax expects a potential with the opposite sign of model.logpt
+        return -res
+
+    nuts_kernel = NUTS(
+        potential_fn=logp_fn_wrap,
+        target_accept_prob=target_accept,
+        adapt_step_size=True,
+        adapt_mass_matrix=True,
+        dense_mass=False,
+    )
+
+    pmap_numpyro = MCMC(
+        nuts_kernel,
+        num_warmup=tune,
+        num_samples=draws,
+        num_chains=chains,
+        postprocess_fn=None,
+        chain_method="parallel",
+        progress_bar=progress_bar,
     )
-    tic2 = pd.Timestamp.now()
 
+    tic2 = pd.Timestamp.now()
     print("Compilation time = ", tic2 - tic1, file=sys.stdout)
 
     print("Sampling...", file=sys.stdout)
 
-    *mcmc_samples, leapfrogs_taken = _sample()
-    tic3 = pd.Timestamp.now()
+    seed = jax.random.PRNGKey(random_seed)
+    map_seed = jax.random.split(seed, chains)
 
+    pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
+    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
+
+    tic3 = pd.Timestamp.now()
     print("Sampling time = ", tic3 - tic2, file=sys.stdout)
 
-    posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}
+    print("Transforming variables...", file=sys.stdout)
+    mcmc_samples = []
+    for i, (value_var, raw_samples) in enumerate(zip(model.value_vars, raw_mcmc_samples)):
+        raw_samples = at.constant(np.asarray(raw_samples))
+
+        rv = model.values_to_rvs[value_var]
+        transform = getattr(value_var.tag, "transform", None)
+
+        if transform is not None:
+            # TODO: This will fail when the transformation depends on another variable
+            # such as in interval transform with RVs as edges
+            trans_samples = transform.backward(rv, raw_samples)
+            trans_samples.name = rv.name
+            mcmc_samples.append(trans_samples)
+
+            if keep_untransformed:
+                raw_samples.name = value_var.name
+                mcmc_samples.append(raw_samples)
+        else:
+            raw_samples.name = rv.name
+            mcmc_samples.append(raw_samples)
+
+    mcmc_varnames = [var.name for var in mcmc_samples]
+    mcmc_samples = compile_rv_inplace(
+        [],
+        mcmc_samples,
+        mode="JAX",
+    )()
+
+    tic4 = pd.Timestamp.now()
+    print("Transformation time = ", tic4 - tic3, file=sys.stdout)
 
+    posterior = {k: v for k, v in zip(mcmc_varnames, mcmc_samples)}
     az_trace = az.from_dict(posterior=posterior)
 
     return az_trace
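
The public entry point is unchanged by the refactor: sample_numpyro_nuts is called from a model context and returns an ArviZ trace built with az.from_dict. A minimal usage sketch (editorial; the toy model and the pymc.sampling_jax module path are assumptions):

import numpy as np
import pymc as pm

from pymc.sampling_jax import sample_numpyro_nuts

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu, sigma, observed=np.array([0.1, -0.4, 0.7]))
    # draws/tune/chains map onto num_samples/num_warmup/num_chains above
    idata = sample_numpyro_nuts(draws=500, tune=500, chains=2)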