Skip to content

Commit 663a60a

Browse files
committed
changed pathfinder samples argument to num_draws
1 parent 8835cd5 commit 663a60a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pymc_experimental/inference/pathfinder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def convert_flat_trace_to_idata(
6262

6363

6464
def fit_pathfinder(
65-
samples=1000,
65+
num_draws=1000,
6666
random_seed: RandomSeed | None = None,
6767
postprocessing_backend="cpu",
6868
model=None,
@@ -125,12 +125,12 @@ def logprob_fn(x):
125125
pathfinder_samples, logq = blackjax.vi.pathfinder.sample(
126126
rng_key=jax.random.key(sample_seed),
127127
state=pathfinder_state,
128-
num_samples=samples,
128+
num_samples=num_draws,
129129
)
130130

131131
idata = convert_flat_trace_to_idata(
132132
pathfinder_samples,
133133
postprocessing_backend=postprocessing_backend,
134134
model=model,
135135
)
136-
return pathfinder_state, pathfinder_info, pathfinder_samples, logq, idata
136+
return pathfinder_state, pathfinder_info, pathfinder_samples, idata

0 commit comments

Comments
 (0)