@@ -1,20 +1,21 @@
-import warnings
 import os
 import re
-xla_flags = os.getenv('XLA_FLAGS', '').lstrip('--')
-xla_flags = re.sub(r'xla_force_host_platform_device_count=.+\s', '', xla_flags).split()
-os.environ['XLA_FLAGS'] = ' '.join(['--xla_force_host_platform_device_count={}'.format(100)])
+import warnings
 
+xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
+xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
+os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
+
+import arviz as az
+import jax
 import numpy as np
 import pandas as pd
-
 import theano
 import theano.sandbox.jax_linker
 import theano.sandbox.jaxify
-import jax
 
-import arviz as az
 import pymc3 as pm
+
 from pymc3 import modelcontext
 
 warnings.warn("This module is experimental.")
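Note on the preamble above: the `XLA_FLAGS` manipulation has to happen before `jax` is imported, because XLA reads the flag only once at initialization. Forcing `xla_force_host_platform_device_count=100` makes XLA expose 100 virtual CPU devices, which is what lets `jax.pmap` below run one chain per device. A minimal standalone sketch of the same mechanism (the device count of 8 is an arbitrary illustration, not from this commit):

import os

# XLA reads this flag once, at first import of jax, so it must be set first.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

print(jax.local_device_count())  # -> 8 virtual CPU devices, one per pmap-able chain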
@@ -24,51 +25,58 @@
 # This will make the JAX Linker the default
 # theano.config.mode = "JAX"
 
-def sample_tfp_nuts(draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=10, model=None,
-                    num_tuning_epoch=2, num_compute_step_size=500):
+
+def sample_tfp_nuts(
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    model=None,
+    num_tuning_epoch=2,
+    num_compute_step_size=500,
+):
     from tensorflow_probability.substrates import jax as tfp
+
     model = modelcontext(model)
-
+
     seed = jax.random.PRNGKey(random_seed)
-
+
     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
     logp_fn_jax = fns[0]
 
     rv_names = [rv.name for rv in model.free_RVs]
     init_state = [model.test_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(
-        lambda x: np.repeat(x[None, ...], chains, axis=0),
-        init_state)
+    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
 
     @jax.pmap
     def _sample(init_state, seed):
         def gen_kernel(step_size):
-            hmc = tfp.mcmc.NoUTurnSampler(
-                target_log_prob_fn=logp_fn_jax, step_size=step_size)
+            hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
             return tfp.mcmc.DualAveragingStepSizeAdaptation(
-                hmc, tune // num_tuning_epoch,
-                target_accept_prob=target_accept)
+                hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
+            )
 
         def trace_fn(_, pkr):
             return pkr.new_step_size
-
+
         def get_tuned_stepsize(samples, step_size):
             return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])
 
         step_size = jax.tree_map(jax.numpy.ones_like, init_state)
-        for i in range(num_tuning_epoch-1):
+        for i in range(num_tuning_epoch - 1):
             tuning_hmc = gen_kernel(step_size)
             init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
                 num_results=tune // num_tuning_epoch,
                 current_state=init_state,
                 kernel=tuning_hmc,
                 trace_fn=trace_fn,
                 return_final_kernel_results=True,
-                seed=seed)
+                seed=seed,
+            )
 
-            step_size = jax.tree_multimap(
-                get_tuned_stepsize, list(init_samples), tuning_result)
+            step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
             init_state = [x[-1] for x in init_samples]
 
         # Run inference
@@ -79,47 +87,55 @@ def get_tuned_stepsize(samples, step_size):
             current_state=init_state,
             kernel=sample_kernel,
             trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
-            seed=seed)
-
+            seed=seed,
+        )
+
         return mcmc_samples, leapfrog_num
-
+
     print("Compiling...")
     tic2 = pd.Timestamp.now()
     map_seed = jax.random.split(seed, chains)
     mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
     tic3 = pd.Timestamp.now()
     print("Compilation + sampling time = ", tic3 - tic2)
-
+
     # map_seed = jax.random.split(seed, chains)
    # mcmc_samples = _sample(init_state_batched, map_seed)
     # tic4 = pd.Timestamp.now()
     # print("Sampling time = ", tic4 - tic3)
-
+
     posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
 
     az_trace = az.from_dict(posterior=posterior)
-    return az_trace # , leapfrog_num, tic3 - tic2
+    return az_trace  # , leapfrog_num, tic3 - tic2
 
 import jax
 
+
 def sample_numpyro_nuts(
-    draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=10, model=None, progress_bar=True):
+    draws=1000,
+    tune=1000,
+    chains=4,
+    target_accept=0.8,
+    random_seed=10,
+    model=None,
+    progress_bar=True,
+):
     from numpyro.infer import MCMC, NUTS
 
     from pymc3 import modelcontext
+
     model = modelcontext(model)
-
+
     seed = jax.random.PRNGKey(random_seed)
-
+
     fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
     fns = theano.sandbox.jaxify.jax_funcify(fgraph)
     logp_fn_jax = fns[0]
 
     rv_names = [rv.name for rv in model.free_RVs]
     init_state = [model.test_point[rv_name] for rv_name in rv_names]
-    init_state_batched = jax.tree_map(
-        lambda x: np.repeat(x[None, ...], chains, axis=0),
-        init_state)
+    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
 
     @jax.jit
     def _sample(current_state, seed):
@@ -130,30 +146,37 @@ def _sample(current_state, seed):
             target_accept_prob=target_accept,
             adapt_step_size=True,
             adapt_mass_matrix=True,
-            dense_mass=False)
+            dense_mass=False,
+        )
 
         pmap_numpyro = MCMC(
-            nuts_kernel, num_warmup=tune, num_samples=draws, num_chains=chains,
-            postprocess_fn=None, chain_method='parallel', progress_bar=progress_bar)
-
-        pmap_numpyro.run(seed, init_params=current_state, extra_fields=('num_steps',))
+            nuts_kernel,
+            num_warmup=tune,
+            num_samples=draws,
+            num_chains=chains,
+            postprocess_fn=None,
+            chain_method="parallel",
+            progress_bar=progress_bar,
+        )
+
+        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
         samples = pmap_numpyro.get_samples(group_by_chain=True)
-        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)['num_steps']
+        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
         return samples, leapfrogs_taken
-
+
     print("Compiling...")
     tic2 = pd.Timestamp.now()
     map_seed = jax.random.split(seed, chains)
     mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
     tic3 = pd.Timestamp.now()
     print("Compilation + sampling time = ", tic3 - tic2)
-
+
     # map_seed = jax.random.split(seed, chains)
     # mcmc_samples = _sample(init_state_batched, map_seed)
     # tic4 = pd.Timestamp.now()
     # print("Sampling time = ", tic4 - tic3)
-
+
     posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
 
     az_trace = az.from_dict(posterior=posterior)
-    return az_trace # , leapfrogs_taken, tic3 - tic2
+    return az_trace  # , leapfrogs_taken, tic3 - tic2
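For context, both samplers resolve their model through modelcontext(), so they can be called with model=None from inside a pm.Model() block. A hedged usage sketch (the toy model below and the sampling_jax import path are illustrative assumptions, not part of this diff):

import numpy as np
import pymc3 as pm

from pymc3 import sampling_jax  # assumed import path for the module above

data = np.random.randn(100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=data)

    # model=None lets modelcontext() pick up the enclosing model; the return
    # value is an arviz.InferenceData built via az.from_dict.
    trace = sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=4)

print(trace.posterior["mu"].shape)  # -> (4, 1000): (chains, draws)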