from pymc.model.core import Point
from pymc.sampling.jax import get_jaxified_graph
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
+from pytensor.graph import Apply, Op

from pymc_experimental.inference.lbfgs import lbfgs

@@ -311,7 +312,8 @@ def bfgs_sample(
    alpha,
    beta,
    gamma,
-    random_seed: RandomSeed | None = None,
+    rng,
+    # random_seed: RandomSeed | None = None,
):
    # batch: L = 8
    # alpha_l: (N,) => (L, N)
@@ -324,8 +326,6 @@ def bfgs_sample(
    # logdensity: (M,) => (L, M)
    # theta: (J, N)

-    rng = pytensor.shared(np.random.default_rng(seed=random_seed))
-
    if not _batched(x, g, alpha, beta, gamma):
        x = pt.atleast_2d(x)
        g = pt.atleast_2d(g)
@@ -371,6 +371,24 @@ def bfgs_sample(
    return phi, logdensity


+class LogLike(Op):
+    def __init__(self, logp_func):
+        self.logp_func = logp_func
+        super().__init__()
+
+    def make_node(self, phi_node):
+        # Convert inputs to tensor variables
+        phi_node = pt.as_tensor(phi_node)
+        output_type = pt.tensor(dtype=phi_node.dtype, shape=(None, None))
+        return Apply(self, [phi_node], [output_type])
+
+    def perform(self, node: Apply, phi_node, outputs) -> None:
+        phi_node = phi_node[0]
+        logp_node = np.apply_along_axis(self.logp_func, axis=-1, arr=phi_node)
+        # outputs[0][0] = np.asarray(logp)
+        outputs[0][0] = logp_node
+
+
def _pymc_pathfinder(
    model,
    x0: np.float64,
@@ -406,38 +424,43 @@ def neg_dlogp_func(x):
        gtol=gtol,
        maxls=maxls,
    )
+    x = pytensor.shared(history.x, "x")
+    g = pytensor.shared(history.g, "g")

-    alpha, update_mask = alpha_recover(history.x, history.g)
-
-    beta, gamma = inverse_hessian_factors(alpha, history.x, history.g, update_mask, J=maxcor)
+    alpha, update_mask = alpha_recover(x, g)

-    phi, logq_phi = bfgs_sample(
+    beta, gamma = inverse_hessian_factors(alpha, x, g, update_mask, J=maxcor)
+    rng = pytensor.shared(np.random.default_rng(seed=pathfinder_seed))
+    _phi, _logq_phi = bfgs_sample(
        num_samples=num_elbo_draws,
-        x=history.x,
-        g=history.g,
+        x=x,
+        g=g,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
-        random_seed=pathfinder_seed,
+        rng=rng,
    )
+    sample_phi_fn = pytensor.function([alpha, beta, gamma], [_phi, _logq_phi])
+    phi, logq_phi = sample_phi_fn(alpha.eval(), beta.eval(), gamma.eval())

    # .vectorize is slower than apply_along_axis
-    logp_phi = np.apply_along_axis(logp_func, axis=-1, arr=phi.eval())
-    logq_phi = logq_phi.eval()
-    elbo = (logp_phi - logq_phi).mean(axis=-1)
-    lstar = np.argmax(elbo)
+    loglike = LogLike(logp_func)
+    logp_phi = loglike(phi)
+    elbo = pt.mean(logp_phi - logq_phi, axis=-1)
+    l_star = pt.argmax(elbo)

+    rng.set_value(np.random.default_rng(seed=sample_seed))
    psi, logq_psi = bfgs_sample(
        num_samples=num_draws,
-        x=history.x[lstar],
-        g=history.g[lstar],
-        alpha=alpha[lstar],
-        beta=beta[lstar],
-        gamma=gamma[lstar],
-        random_seed=sample_seed,
+        x=x[l_star],
+        g=g[l_star],
+        alpha=alpha[l_star],
+        beta=beta[l_star],
+        gamma=gamma[l_star],
+        rng=rng,
    )

-    return psi[0].eval(), logq_psi, logp_func
+    return psi[0].eval(), logq_psi.eval()


def fit_pathfinder(
@@ -492,7 +515,7 @@ def fit_pathfinder(

    # TODO: make better
    if inference_backend == "pymc":
-        pathfinder_samples, logq_psi, logp_func = _pymc_pathfinder(
+        pathfinder_samples, logq_psi = _pymc_pathfinder(
            model,
            ip,
            maxcor=maxcor,
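
Note on the main change: the diff moves the ELBO computation into the PyTensor graph by wrapping the compiled logp_func in a custom Op (LogLike), so pt.mean and pt.argmax can operate on its output symbolically. Below is a minimal standalone sketch of that wrapping pattern, not part of this commit; the names ToyLogLike and toy_logp, the standard-normal log-density, and the (L, M, N) shapes are placeholders chosen for illustration.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import Apply, Op


def toy_logp(draw: np.ndarray) -> float:
    # Stand-in for a model's compiled logp function: maps one draw of shape (N,) to a scalar.
    return -0.5 * float(np.sum(draw**2))


class ToyLogLike(Op):
    def __init__(self, logp_func):
        self.logp_func = logp_func
        super().__init__()

    def make_node(self, phi):
        phi = pt.as_tensor(phi)
        # One log-density per draw: (L, M, N) input -> (L, M) output.
        out = pt.tensor(dtype="float64", shape=(None, None))
        return Apply(self, [phi], [out])

    def perform(self, node, inputs, outputs):
        (phi,) = inputs
        # Apply the NumPy-level logp along the parameter axis.
        outputs[0][0] = np.apply_along_axis(self.logp_func, axis=-1, arr=phi)


phi = pt.tensor3("phi", dtype="float64")      # (L, M, N) draws
logp_phi = ToyLogLike(toy_logp)(phi)          # symbolic (L, M) log-densities
best = pt.argmax(pt.mean(logp_phi, axis=-1))  # pick the best path, as the diff does with the ELBO
f = pytensor.function([phi], [logp_phi, best])
logp_vals, best_idx = f(np.random.default_rng(0).normal(size=(2, 5, 3)))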
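The other notable change is threading a shared RNG through bfgs_sample instead of passing random_seed: the generator is created once with pytensor.shared(np.random.default_rng(...)) and reseeded in place with rng.set_value(...) before the final psi draw. The following is a small sketch of that shared-RNG pattern in isolation, separate from this commit; the toy normal draw, seeds, and variable names are placeholders.

import numpy as np
import pytensor
import pytensor.tensor as pt

# Shared variable holding a NumPy Generator; a random op draws from it.
rng = pytensor.shared(np.random.default_rng(seed=123), name="rng")
draws = pt.random.normal(0.0, 1.0, size=(3,), rng=rng)
sample = pytensor.function([], draws)

a = sample()
b = sample()  # identical to `a`: with no update rule the shared generator state does not advance
rng.set_value(np.random.default_rng(seed=456))
c = sample()  # a different stream after reseeding in place, as done before the final psi draw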