@@ -1,5 +1,6 @@
 import dataclasses
 import itertools
+import warnings
 from dataclasses import dataclass
 from importlib.util import find_spec
 from math import prod
@@ -200,12 +201,26 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
     logp_numba_raw, c_sig = _make_c_logp_func(
         n_dim, logp_fn, user_data, shared_logp, shared_data
     )
-    logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message="Cannot cache compiled function .* as it uses dynamic globals",
+            category=numba.NumbaWarning,
+        )
+
+        logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
 
     expand_numba_raw, c_sig_expand = _make_c_expand_func(
         n_dim, n_expanded, expand_fn, user_data, shared_expand, shared_data
     )
-    expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message="Cannot cache compiled function .* as it uses dynamic globals",
+            category=numba.NumbaWarning,
+        )
+
+        expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
 
     coords = {}
     for name, vals in model.coords.items():
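
Note: numba emits a NumbaWarning ("Cannot cache compiled function ... as it uses dynamic globals") whenever a cfunc closes over state that prevents on-disk caching, which is expected for these generated functions. The hunk above therefore silences exactly that message, and only around the two cfunc compilations. A minimal runnable sketch of the same pattern, using a plain UserWarning so it works without numba installed:

    import warnings

    def noisy():
        # Stand-in for numba's dynamic-globals warning.
        warnings.warn("Cannot cache compiled function 'f' as it uses dynamic globals")

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Cannot cache compiled function .* as it uses dynamic globals",
        )
        noisy()  # suppressed: the message matches the regex filter

    noisy()  # warns again: catch_warnings() restores the previous filters on exit
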
@@ -276,6 +291,7 @@ def _make_functions(model):
     import pytensor
     import pytensor.link.numba.dispatch
     import pytensor.tensor as pt
+    from pymc.pytensorf import compile_pymc
 
     shapes = _compute_shapes(model)
 
@@ -340,9 +356,8 @@ def _make_functions(model):
     (logp, grad) = pytensor.graph_replace([logp, grad], replacements)
 
     # We should avoid compiling the function, and optimize only
-    logp_fn_pt = pytensor.compile.function.function(
-        (joined,), (logp, grad), mode=pytensor.compile.NUMBA
-    )
+    with model:
+        logp_fn_pt = compile_pymc((joined,), (logp, grad), mode=pytensor.compile.NUMBA)
 
     logp_fn = logp_fn_pt.vm.jit_fn
 
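
Note: compile_pymc is PyMC's wrapper around pytensor's function compiler; unlike the plain pytensor.compile.function.function call it replaces, it also registers the updates that stateful random-variable Ops need. Compiling inside `with model:` presumably keeps model-level settings (such as check_bounds) visible to compile_pymc, which looks them up from the model context. A hedged usage sketch, assuming a PyMC 5.x release that still exposes pymc.pytensorf.compile_pymc:

    import pymc as pm
    import pytensor
    import pytensor.tensor as pt
    from pymc.pytensorf import compile_pymc

    with pm.Model() as model:
        pm.Normal("x")

        # Compile value -> (logp, grad) with the numba backend, mirroring
        # what the diff does for the joined parameter vector.
        value = model.value_vars[0]
        logp = model.logp(sum=True)
        grad = pt.grad(logp, wrt=value)
        fn = compile_pymc((value,), (logp, grad), mode=pytensor.compile.NUMBA)

    print(fn(0.5))  # [logp at x=0.5, dlogp/dx at x=0.5]
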
@@ -368,12 +383,13 @@ def _make_functions(model):
     num_expanded = count
 
     allvars = pt.concatenate([joined, *[var.ravel() for var in remaining_rvs]])
-    expand_fn_pt = pytensor.compile.function.function(
-        (joined,),
-        (allvars,),
-        givens=list(replacements.items()),
-        mode=pytensor.compile.NUMBA,
-    )
+    with model:
+        expand_fn_pt = compile_pymc(
+            (joined,),
+            (allvars,),
+            givens=list(replacements.items()),
+            mode=pytensor.compile.NUMBA,
+        )
 
     expand_fn = expand_fn_pt.vm.jit_fn
 
     return (
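
Note: both compiled pytensor functions are used only through their .vm.jit_fn attribute, the numba dispatcher that pytensor's numba linker attaches to the function's virtual machine; compile_pymc_model then wraps those dispatchers in the cfuncs from the first hunk so the sampler can call them through plain C function pointers. Continuing the sketch above (hedged: jit_fn's calling convention is an internal detail of pytensor's numba backend and may change):

    import numpy as np

    jit_fn = fn.vm.jit_fn  # the underlying numba dispatcher
    # Raw positional inputs in, tuple of outputs out, no Python wrapper overhead.
    logp_val, grad_val = jit_fn(np.float64(0.5))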