15
15
import collections
16
16
import sys
17
17
18
+ from collections .abc import Callable
19
+
18
20
import arviz as az
19
21
import blackjax
20
22
import jax
21
23
import numpy as np
22
24
import pymc as pm
23
25
24
26
from packaging import version
27
+ from pymc import Model
25
28
from pymc .backends .arviz import coords_and_dims_for_inferencedata
26
29
from pymc .blocking import DictToArrayBijection , RaveledVars
30
+ from pymc .initial_point import make_initial_point_fn
27
31
from pymc .model import modelcontext
32
+ from pymc .model .core import Point
28
33
from pymc .sampling .jax import get_jaxified_graph
29
34
from pymc .util import RandomSeed , _get_seeds_per_chain , get_default_varnames
30
35
31
36
37
+ def get_jaxified_logp_ravel_inputs (
38
+ model : Model ,
39
+ initial_points : dict | None = None ,
40
+ ) -> tuple [Callable , DictToArrayBijection ]:
41
+ """
42
+ Get jaxified logp function and ravel inputs for a PyMC model.
43
+
44
+ Parameters
45
+ ----------
46
+ model : Model
47
+ PyMC model to jaxify.
48
+
49
+ Returns
50
+ -------
51
+ tuple[Callable, DictToArrayBijection]
52
+ A tuple containing the jaxified logp function and the DictToArrayBijection.
53
+ """
54
+
55
+ new_logprob , new_input = pm .pytensorf .join_nonshared_inputs (
56
+ initial_points , (model .logp (),), model .value_vars , ()
57
+ )
58
+
59
+ logprob_fn_list = get_jaxified_graph ([new_input ], new_logprob )
60
+
61
+ def logprob_fn (x ):
62
+ return logprob_fn_list (x )[0 ]
63
+
64
+ return logprob_fn , DictToArrayBijection .map (initial_points )
65
+
66
+
32
67
def convert_flat_trace_to_idata (
33
68
samples ,
34
69
include_transformed = False ,
@@ -37,7 +72,7 @@ def convert_flat_trace_to_idata(
37
72
):
38
73
model = modelcontext (model )
39
74
ip = model .initial_point ()
40
- ip_point_map_info = pm . blocking . DictToArrayBijection .map (ip ).point_map_info
75
+ ip_point_map_info = DictToArrayBijection .map (ip ).point_map_info
41
76
trace = collections .defaultdict (list )
42
77
for sample in samples :
43
78
raveld_vars = RaveledVars (sample , ip_point_map_info )
@@ -62,10 +97,10 @@ def convert_flat_trace_to_idata(
62
97
63
98
64
99
def fit_pathfinder (
100
+ model = None ,
65
101
num_draws = 1000 ,
66
102
random_seed : RandomSeed | None = None ,
67
103
postprocessing_backend = "cpu" ,
68
- model = None ,
69
104
** pathfinder_kwargs ,
70
105
):
71
106
"""
@@ -99,19 +134,16 @@ def fit_pathfinder(
99
134
100
135
model = modelcontext (model )
101
136
102
- ip = model .initial_point ()
103
- ip_map = DictToArrayBijection .map (ip )
137
+ [jitter_seed , pathfinder_seed , sample_seed ] = _get_seeds_per_chain (random_seed , 3 )
104
138
105
- new_logprob , new_input = pm .pytensorf .join_nonshared_inputs (
106
- ip , (model .logp (),), model .value_vars , ()
139
+ # set initial points. PF requires jittering of initial points
140
+ ipfn = make_initial_point_fn (
141
+ model = model ,
142
+ jitter_rvs = set (model .free_RVs ),
143
+ # TODO: add argument for jitter strategy
107
144
)
108
-
109
- logprob_fn_list = get_jaxified_graph ([new_input ], new_logprob )
110
-
111
- def logprob_fn (x ):
112
- return logprob_fn_list (x )[0 ]
113
-
114
- [pathfinder_seed , sample_seed ] = _get_seeds_per_chain (random_seed , 2 )
145
+ ip = Point (ipfn (jitter_seed ), model = model )
146
+ logprob_fn , ip_map = get_jaxified_logp_ravel_inputs (model , initial_points = ip )
115
147
116
148
print ("Running pathfinder..." , file = sys .stdout )
117
149
pathfinder_state , pathfinder_info = blackjax .vi .pathfinder .approximate (
@@ -120,17 +152,17 @@ def logprob_fn(x):
120
152
initial_position = ip_map .data ,
121
153
** pathfinder_kwargs ,
122
154
)
123
-
124
- # retrieved logq
125
- pathfinder_samples , logq = blackjax .vi .pathfinder .sample (
155
+ pathfinder_samples , _ = blackjax .vi .pathfinder .sample (
126
156
rng_key = jax .random .key (sample_seed ),
127
157
state = pathfinder_state ,
128
158
num_samples = num_draws ,
129
159
)
130
160
131
161
idata = convert_flat_trace_to_idata (
162
+ pathfinder_samples ,
132
163
pathfinder_samples ,
133
164
postprocessing_backend = postprocessing_backend ,
134
165
model = model ,
135
166
)
136
167
return pathfinder_state , pathfinder_info , pathfinder_samples , idata
168
+ return pathfinder_state , pathfinder_info , pathfinder_samples , idata
0 commit comments