Skip to content

Commit 7e1d356

Browse files
committed
Black and isort.
1 parent d4da86c commit 7e1d356

File tree

1 file changed

+68
-45
lines changed

1 file changed

+68
-45
lines changed

pymc3/sampling_jax.py

Lines changed: 68 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
import warnings
21
import os
32
import re
4-
xla_flags = os.getenv('XLA_FLAGS', '').lstrip('--')
5-
xla_flags = re.sub(r'xla_force_host_platform_device_count=.+\s', '', xla_flags).split()
6-
os.environ['XLA_FLAGS'] = ' '.join(['--xla_force_host_platform_device_count={}'.format(100)])
3+
import warnings
74

5+
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
6+
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
7+
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)])
8+
9+
import arviz as az
10+
import jax
811
import numpy as np
912
import pandas as pd
10-
1113
import theano
1214
import theano.sandbox.jax_linker
1315
import theano.sandbox.jaxify
14-
import jax
1516

16-
import arviz as az
1717
import pymc3 as pm
18+
1819
from pymc3 import modelcontext
1920

2021
warnings.warn("This module is experimental.")
@@ -24,51 +25,58 @@
2425
# This will make the JAX Linker the default
2526
# theano.config.mode = "JAX"
2627

27-
def sample_tfp_nuts(draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=10, model=None,
28-
num_tuning_epoch=2, num_compute_step_size=500):
28+
29+
def sample_tfp_nuts(
30+
draws=1000,
31+
tune=1000,
32+
chains=4,
33+
target_accept=0.8,
34+
random_seed=10,
35+
model=None,
36+
num_tuning_epoch=2,
37+
num_compute_step_size=500,
38+
):
2939
from tensorflow_probability.substrates import jax as tfp
40+
3041
model = modelcontext(model)
31-
42+
3243
seed = jax.random.PRNGKey(random_seed)
33-
44+
3445
fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
3546
fns = theano.sandbox.jaxify.jax_funcify(fgraph)
3647
logp_fn_jax = fns[0]
3748

3849
rv_names = [rv.name for rv in model.free_RVs]
3950
init_state = [model.test_point[rv_name] for rv_name in rv_names]
40-
init_state_batched = jax.tree_map(
41-
lambda x: np.repeat(x[None, ...], chains, axis=0),
42-
init_state)
51+
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
4352

4453
@jax.pmap
4554
def _sample(init_state, seed):
4655
def gen_kernel(step_size):
47-
hmc = tfp.mcmc.NoUTurnSampler(
48-
target_log_prob_fn=logp_fn_jax, step_size=step_size)
56+
hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
4957
return tfp.mcmc.DualAveragingStepSizeAdaptation(
50-
hmc, tune // num_tuning_epoch,
51-
target_accept_prob=target_accept)
58+
hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
59+
)
5260

5361
def trace_fn(_, pkr):
5462
return pkr.new_step_size
55-
63+
5664
def get_tuned_stepsize(samples, step_size):
5765
return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])
5866

5967
step_size = jax.tree_map(jax.numpy.ones_like, init_state)
60-
for i in range(num_tuning_epoch-1):
68+
for i in range(num_tuning_epoch - 1):
6169
tuning_hmc = gen_kernel(step_size)
6270
init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
6371
num_results=tune // num_tuning_epoch,
6472
current_state=init_state,
6573
kernel=tuning_hmc,
6674
trace_fn=trace_fn,
6775
return_final_kernel_results=True,
68-
seed=seed)
76+
seed=seed,
77+
)
6978

70-
step_size = jax.tree_multimap(
71-
get_tuned_stepsize, list(init_samples), tuning_result)
79+
step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
7280
init_state = [x[-1] for x in init_samples]
7381

7482
# Run inference
@@ -79,47 +87,55 @@ def get_tuned_stepsize(samples, step_size):
7987
current_state=init_state,
8088
kernel=sample_kernel,
8189
trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
82-
seed=seed)
83-
90+
seed=seed,
91+
)
92+
8493
return mcmc_samples, leapfrog_num
85-
94+
8695
print("Compiling...")
8796
tic2 = pd.Timestamp.now()
8897
map_seed = jax.random.split(seed, chains)
8998
mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
9099
tic3 = pd.Timestamp.now()
91100
print("Compilation + sampling time = ", tic3 - tic2)
92-
101+
93102
# map_seed = jax.random.split(seed, chains)
94103
# mcmc_samples = _sample(init_state_batched, map_seed)
95104
# tic4 = pd.Timestamp.now()
96105
# print("Sampling time = ", tic4 - tic3)
97-
106+
98107
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
99108

100109
az_trace = az.from_dict(posterior=posterior)
101-
return az_trace #, leapfrog_num, tic3 - tic2
110+
return az_trace # , leapfrog_num, tic3 - tic2
102111

103112
import jax
104113

114+
105115
def sample_numpyro_nuts(
106-
draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=10, model=None, progress_bar=True):
116+
draws=1000,
117+
tune=1000,
118+
chains=4,
119+
target_accept=0.8,
120+
random_seed=10,
121+
model=None,
122+
progress_bar=True,
123+
):
107124
from numpyro.infer import MCMC, NUTS
108125

109126
from pymc3 import modelcontext
127+
110128
model = modelcontext(model)
111-
129+
112130
seed = jax.random.PRNGKey(random_seed)
113-
131+
114132
fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
115133
fns = theano.sandbox.jaxify.jax_funcify(fgraph)
116134
logp_fn_jax = fns[0]
117135

118136
rv_names = [rv.name for rv in model.free_RVs]
119137
init_state = [model.test_point[rv_name] for rv_name in rv_names]
120-
init_state_batched = jax.tree_map(
121-
lambda x: np.repeat(x[None, ...], chains, axis=0),
122-
init_state)
138+
init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
123139

124140
@jax.jit
125141
def _sample(current_state, seed):
@@ -130,30 +146,37 @@ def _sample(current_state, seed):
130146
target_accept_prob=target_accept,
131147
adapt_step_size=True,
132148
adapt_mass_matrix=True,
133-
dense_mass=False)
149+
dense_mass=False,
150+
)
134151

135152
pmap_numpyro = MCMC(
136-
nuts_kernel, num_warmup=tune, num_samples=draws, num_chains=chains,
137-
postprocess_fn=None, chain_method='parallel', progress_bar=progress_bar)
138-
139-
pmap_numpyro.run(seed, init_params=current_state, extra_fields=('num_steps',))
153+
nuts_kernel,
154+
num_warmup=tune,
155+
num_samples=draws,
156+
num_chains=chains,
157+
postprocess_fn=None,
158+
chain_method="parallel",
159+
progress_bar=progress_bar,
160+
)
161+
162+
pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
140163
samples = pmap_numpyro.get_samples(group_by_chain=True)
141-
leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)['num_steps']
164+
leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
142165
return samples, leapfrogs_taken
143-
166+
144167
print("Compiling...")
145168
tic2 = pd.Timestamp.now()
146169
map_seed = jax.random.split(seed, chains)
147170
mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
148171
tic3 = pd.Timestamp.now()
149172
print("Compilation + sampling time = ", tic3 - tic2)
150-
173+
151174
# map_seed = jax.random.split(seed, chains)
152175
# mcmc_samples = _sample(init_state_batched, map_seed)
153176
# tic4 = pd.Timestamp.now()
154177
# print("Sampling time = ", tic4 - tic3)
155-
178+
156179
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
157180

158181
az_trace = az.from_dict(posterior=posterior)
159-
return az_trace #, leapfrogs_taken, tic3 - tic2
182+
return az_trace # , leapfrogs_taken, tic3 - tic2

0 commit comments

Comments
 (0)