15
15
import collections
16
16
import sys
17
17
18
+ from collections .abc import Callable
19
+
18
20
import arviz as az
19
21
import blackjax
20
22
import jax
21
23
import numpy as np
22
24
import pymc as pm
23
25
24
26
from packaging import version
27
+ from pymc import Model
25
28
from pymc .backends .arviz import coords_and_dims_for_inferencedata
26
29
from pymc .blocking import DictToArrayBijection , RaveledVars
30
+ from pymc .initial_point import make_initial_point_fn
27
31
from pymc .model import modelcontext
32
+ from pymc .model .core import Point
28
33
from pymc .sampling .jax import get_jaxified_graph
29
34
from pymc .util import RandomSeed , _get_seeds_per_chain , get_default_varnames
30
35
31
36
37
+ def get_jaxified_logp_ravel_inputs (
38
+ model : Model ,
39
+ initial_points : dict | None = None ,
40
+ ) -> tuple [Callable , DictToArrayBijection ]:
41
+ """
42
+ Get jaxified logp function and ravel inputs for a PyMC model.
43
+
44
+ Parameters
45
+ ----------
46
+ model : Model
47
+ PyMC model to jaxify.
48
+
49
+ Returns
50
+ -------
51
+ tuple[Callable, DictToArrayBijection]
52
+ A tuple containing the jaxified logp function and the DictToArrayBijection.
53
+ """
54
+
55
+ new_logprob , new_input = pm .pytensorf .join_nonshared_inputs (
56
+ initial_points , (model .logp (),), model .value_vars , ()
57
+ )
58
+
59
+ logprob_fn_list = get_jaxified_graph ([new_input ], new_logprob )
60
+
61
+ def logprob_fn (x ):
62
+ return logprob_fn_list (x )[0 ]
63
+
64
+ return logprob_fn , DictToArrayBijection .map (initial_points )
65
+
66
+
32
67
def convert_flat_trace_to_idata (
33
68
samples ,
34
69
include_transformed = False ,
@@ -37,7 +72,7 @@ def convert_flat_trace_to_idata(
37
72
):
38
73
model = modelcontext (model )
39
74
ip = model .initial_point ()
40
- ip_point_map_info = pm . blocking . DictToArrayBijection .map (ip ).point_map_info
75
+ ip_point_map_info = DictToArrayBijection .map (ip ).point_map_info
41
76
trace = collections .defaultdict (list )
42
77
for sample in samples :
43
78
raveld_vars = RaveledVars (sample , ip_point_map_info )
@@ -62,10 +97,10 @@ def convert_flat_trace_to_idata(
62
97
63
98
64
99
def fit_pathfinder (
65
- num_samples = 1000 ,
100
+ model = None ,
101
+ num_draws = 1000 ,
66
102
random_seed : RandomSeed | None = None ,
67
103
postprocessing_backend = "cpu" ,
68
- model = None ,
69
104
** pathfinder_kwargs ,
70
105
):
71
106
"""
@@ -99,22 +134,19 @@ def fit_pathfinder(
99
134
100
135
model = modelcontext (model )
101
136
102
- ip = model .initial_point ()
103
- ip_map = DictToArrayBijection .map (ip )
137
+ [jitter_seed , pathfinder_seed , sample_seed ] = _get_seeds_per_chain (random_seed , 3 )
104
138
105
- new_logprob , new_input = pm .pytensorf .join_nonshared_inputs (
106
- ip , (model .logp (),), model .value_vars , ()
139
+ # set initial points. PF requires jittering of initial points
140
+ ipfn = make_initial_point_fn (
141
+ model = model ,
142
+ jitter_rvs = set (model .free_RVs ),
143
+ # TODO: add argument for jitter strategy
107
144
)
108
-
109
- logprob_fn_list = get_jaxified_graph ([new_input ], new_logprob )
110
-
111
- def logprob_fn (x ):
112
- return logprob_fn_list (x )[0 ]
113
-
114
- [pathfinder_seed , sample_seed ] = _get_seeds_per_chain (random_seed , 2 )
145
+ ip = Point (ipfn (jitter_seed ), model = model )
146
+ logprob_fn , ip_map = get_jaxified_logp_ravel_inputs (model , initial_points = ip )
115
147
116
148
print ("Running pathfinder..." , file = sys .stdout )
117
- pathfinder_state , _ = blackjax .vi .pathfinder .approximate (
149
+ pathfinder_state , pathfinder_info = blackjax .vi .pathfinder .approximate (
118
150
rng_key = jax .random .key (pathfinder_seed ),
119
151
logdensity_fn = logprob_fn ,
120
152
initial_position = ip_map .data ,
@@ -123,12 +155,12 @@ def logprob_fn(x):
123
155
pathfinder_samples , _ = blackjax .vi .pathfinder .sample (
124
156
rng_key = jax .random .key (sample_seed ),
125
157
state = pathfinder_state ,
126
- num_samples = num_samples ,
158
+ num_samples = num_draws ,
127
159
)
128
160
129
161
idata = convert_flat_trace_to_idata (
130
162
pathfinder_samples ,
131
163
postprocessing_backend = postprocessing_backend ,
132
164
model = model ,
133
165
)
134
- return idata
166
+ return pathfinder_state , pathfinder_info , pathfinder_samples , idata
0 commit comments