10
10
xla_flags = re .sub (r"--xla_force_host_platform_device_count=.+\s" , "" , xla_flags ).split ()
11
11
os .environ ["XLA_FLAGS" ] = " " .join ([f"--xla_force_host_platform_device_count={ 100 } " ] + xla_flags )
12
12
13
+ from datetime import datetime
14
+
13
15
import aesara .tensor as at
14
16
import arviz as az
15
17
import jax
16
18
import numpy as np
17
- import pandas as pd
18
19
19
20
from aeppl .logprob import CheckParameterValue
20
21
from aesara .compile import SharedVariable , Supervisor , mode
@@ -174,7 +175,7 @@ def sample_numpyro_nuts(
174
175
else :
175
176
dims = {}
176
177
177
- tic1 = pd . Timestamp .now ()
178
+ tic1 = datetime .now ()
178
179
print ("Compiling..." , file = sys .stdout )
179
180
180
181
rv_names = [rv .name for rv in model .value_vars ]
@@ -202,7 +203,7 @@ def sample_numpyro_nuts(
202
203
progress_bar = progress_bar ,
203
204
)
204
205
205
- tic2 = pd . Timestamp .now ()
206
+ tic2 = datetime .now ()
206
207
print ("Compilation time = " , tic2 - tic1 , file = sys .stdout )
207
208
208
209
print ("Sampling..." , file = sys .stdout )
@@ -231,7 +232,7 @@ def sample_numpyro_nuts(
231
232
232
233
raw_mcmc_samples = pmap_numpyro .get_samples (group_by_chain = True )
233
234
234
- tic3 = pd . Timestamp .now ()
235
+ tic3 = datetime .now ()
235
236
print ("Sampling time = " , tic3 - tic2 , file = sys .stdout )
236
237
237
238
print ("Transforming variables..." , file = sys .stdout )
@@ -241,7 +242,7 @@ def sample_numpyro_nuts(
241
242
result = jax .vmap (jax .vmap (jax_fn ))(* raw_mcmc_samples )[0 ]
242
243
mcmc_samples [v .name ] = result
243
244
244
- tic4 = pd . Timestamp .now ()
245
+ tic4 = datetime .now ()
245
246
print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
246
247
247
248
if idata_kwargs is None :
0 commit comments