Commit 593fa0d

feat: Filter warnings and compile through pymc
1 parent: e77125c

1 file changed: +27 −11 lines

python/nutpie/compile_pymc.py

@@ -1,5 +1,6 @@
 import dataclasses
 import itertools
+import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
 from math import prod
@@ -200,12 +201,26 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
     logp_numba_raw, c_sig = _make_c_logp_func(
         n_dim, logp_fn, user_data, shared_logp, shared_data
     )
-    logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message="Cannot cache compiled function .* as it uses dynamic globals",
+            category=numba.NumbaWarning,
+        )
+
+        logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)

     expand_numba_raw, c_sig_expand = _make_c_expand_func(
         n_dim, n_expanded, expand_fn, user_data, shared_expand, shared_data
     )
-    expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message="Cannot cache compiled function .* as it uses dynamic globals",
+            category=numba.NumbaWarning,
+        )
+
+        expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)

     coords = {}
     for name, vals in model.coords.items():
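
For reference, the hunk above scopes the filter with warnings.catch_warnings(), so the NumbaWarning is suppressed only while the one cfunc is compiled, and the previous filter state is restored when the block exits. A minimal standalone sketch of the same pattern (the _square function and the float64 signature are invented for illustration, not part of this commit):

import warnings

import numba
from numba import types


def _square(x):
    return x * x


# Scope the filter: only the dynamic-globals caching warning is
# ignored, and only while this one cfunc is compiled; the previous
# filter state is restored when the with-block exits.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Cannot cache compiled function .* as it uses dynamic globals",
        category=numba.NumbaWarning,
    )
    square_c = numba.cfunc(types.float64(types.float64))(_square)

print(square_c.ctypes(3.0))  # 9.0, via the C-callable wrapper

Note that the message argument is a regular expression matched against the start of the warning text, which is why the variable part of the message can be left as ".*".
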
@@ -276,6 +291,7 @@ def _make_functions(model):
     import pytensor
     import pytensor.link.numba.dispatch
     import pytensor.tensor as pt
+    from pymc.pytensorf import compile_pymc

     shapes = _compute_shapes(model)

@@ -340,9 +356,8 @@ def _make_functions(model):
     (logp, grad) = pytensor.graph_replace([logp, grad], replacements)

     # We should avoid compiling the function, and optimize only
-    logp_fn_pt = pytensor.compile.function.function(
-        (joined,), (logp, grad), mode=pytensor.compile.NUMBA
-    )
+    with model:
+        logp_fn_pt = compile_pymc((joined,), (logp, grad), mode=pytensor.compile.NUMBA)

     logp_fn = logp_fn_pt.vm.jit_fn

@@ -368,12 +383,13 @@ def _make_functions(model):
     num_expanded = count

     allvars = pt.concatenate([joined, *[var.ravel() for var in remaining_rvs]])
-    expand_fn_pt = pytensor.compile.function.function(
-        (joined,),
-        (allvars,),
-        givens=list(replacements.items()),
-        mode=pytensor.compile.NUMBA,
-    )
+    with model:
+        expand_fn_pt = compile_pymc(
+            (joined,),
+            (allvars,),
+            givens=list(replacements.items()),
+            mode=pytensor.compile.NUMBA,
+        )
     expand_fn = expand_fn_pt.vm.jit_fn

     return (
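
Besides the warning filters, the last three hunks route compilation through compile_pymc from pymc.pytensorf instead of calling pytensor.compile.function.function directly, and do so inside the model context. compile_pymc forwards to the PyTensor compiler but also applies PyMC's standard setup for the graph, such as handling updates for random-generator shared variables. A rough sketch of the call shape against a toy graph (the quadratic stand-in logp is invented for illustration; only compile_pymc, the model context, and mode=pytensor.compile.NUMBA mirror the commit):

import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import compile_pymc

# A toy quadratic "logp" with its gradient; stand-ins for the graphs
# that _make_functions builds from a real model.
x = pt.dvector("x")
logp = -0.5 * (x**2).sum()
grad = pytensor.grad(logp, x)

# Compile through pymc inside a model context, as the commit does.
with pm.Model():
    fn = compile_pymc((x,), (logp, grad), mode=pytensor.compile.NUMBA)

value, gradient = fn([1.0, 2.0, 3.0])  # value = -7.0, gradient = [-1., -2., -3.]

With the numba backend, the jitted implementation is then available as fn.vm.jit_fn, which is what the surrounding code extracts into logp_fn and expand_fn.
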
