Skip to content

Commit a29a56a

Browse files
committed
refactor: Move threaded sampling to nuts-rs
1 parent 63bb7db commit a29a56a

File tree

10 files changed

+911
-783
lines changed

10 files changed

+911
-783
lines changed

Cargo.lock

Lines changed: 555 additions & 136 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,22 @@ name = "_lib"
2121
crate-type = ["cdylib"]
2222

2323
[dependencies]
24-
nuts-rs = "0.8.0"
25-
numpy = "0.20.0"
24+
nuts-rs = "0.9.0"
25+
numpy = "0.21.0"
2626
ndarray = "0.15.6"
2727
rand = "0.8.5"
2828
thiserror = "1.0.44"
2929
rand_chacha = "0.3.1"
3030
rayon = "1.9.0"
31-
arrow2 = "0.17.0"
31+
arrow2 = "0.18.0"
3232
anyhow = "1.0.72"
3333
itertools = "0.12.0"
3434
bridgestan = "2.1.2"
3535
rand_distr = "0.4.3"
3636
smallvec = "1.11.0"
3737

3838
[dependencies.pyo3]
39-
version = "0.20.0"
39+
version = "0.21.0"
4040
features = ["extension-module", "anyhow"]
4141

4242
[dev-dependencies]

python/nutpie/compile_pymc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ def with_data(self, **updates):
8787
user_data=user_data,
8888
)
8989

90-
def _make_sampler(self, settings, init_mean, chains, cores, seed, callback=None):
90+
def _make_sampler(self, settings, init_mean, cores, callback=None):
9191
model = self._make_model(init_mean)
92-
return _lib.PySampler.from_pymc(settings, chains, cores, model, seed, callback)
92+
return _lib.PySampler.from_pymc(settings, cores, model, callback)
9393

9494
def _make_model(self, init_mean):
9595
expand_fn = _lib.ExpandFunc(

python/nutpie/compile_stan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def _make_model(self, init_mean):
8080
return self.with_data().model
8181
return self.model
8282

83-
def _make_sampler(self, settings, init_mean, chains, cores, seed, callback=None):
83+
def _make_sampler(self, settings, init_mean, cores, callback=None):
8484
model = self._make_model(init_mean)
85-
return _lib.PySampler.from_stan(settings, chains, cores, model, seed, callback)
85+
return _lib.PySampler.from_stan(settings, cores, model, callback)
8686

8787
@property
8888
def n_dim(self):

python/nutpie/sample.py

Lines changed: 98 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,87 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
128128
)
129129

130130

131+
class _ChainProgress:
132+
bar: Any
133+
total: int
134+
finished: bool
135+
tuning: bool
136+
draws: int
137+
started: bool
138+
chain_id: int
139+
num_divs: int
140+
step_size: float
141+
num_steps: int
142+
143+
def __init__(self, total, chain_id):
144+
self.bar = fastprogress.progress_bar(range(total))
145+
self.total = total
146+
self.finished = False
147+
self.tuning = True
148+
self.draws = 0
149+
self.started = False
150+
self.chain_id = chain_id
151+
self.num_divs = 0
152+
self.step_size = 0.
153+
self.num_steps = 0
154+
self.bar.update(0)
155+
156+
def callback(self, info):
157+
try:
158+
if self.finished:
159+
return
160+
161+
if info.finished_draws == info.total_draws:
162+
self.finished = True
163+
164+
if not info.tuning:
165+
self.tuning = False
166+
167+
self.draws = info.finished_draws
168+
if info.started:
169+
self.started = True
170+
171+
self.num_divs = info.divergences
172+
self.step_size = info.step_size
173+
self.num_steps = info.num_steps
174+
175+
if self.tuning:
176+
state = "warmup"
177+
else:
178+
state = "sampling"
179+
180+
self.bar.comment = (
181+
f"Chain {self.chain_id:2} {state}: "
182+
f"trajectory {self.num_steps:3} / "
183+
f"diverging {self.num_divs: 2} / "
184+
f"step {self.step_size:.2g}"
185+
)
186+
self.bar.update(self.draws)
187+
except Exception as e:
188+
print(e)
189+
190+
191+
class _DetailedProgress:
192+
chains: list[_ChainProgress]
193+
194+
def __init__(self, total_draws, num_chains):
195+
self.chains = [_ChainProgress(total_draws, i) for i in range(num_chains)]
196+
197+
def callback(self, info):
198+
for chain, chain_info in zip(self.chains, info):
199+
chain.callback(chain_info)
200+
201+
202+
class _SummaryProgress:
203+
bar: Any
204+
205+
def __init__(self, total_draws, num_chains):
206+
pass
207+
208+
def callback(self, info):
209+
return None
210+
211+
131212
class _BackgroundSampler:
132213
_sampler: Any
133214
_num_divs: int
@@ -144,81 +225,30 @@ def __init__(
144225
compiled_model,
145226
settings,
146227
init_mean,
147-
chains,
148228
cores,
149-
seed,
150-
draws,
151-
tune,
152229
*,
153230
progress_bar=True,
154231
save_warmup=True,
155232
return_raw_trace=False,
156233
):
157-
self._num_divs = 0
158-
self._tune = settings.num_tune
159-
self._draws = settings.num_draws
160234
self._settings = settings
161-
self._chains_tuning = chains
162-
self._chains_finished = 0
163-
self._chains = chains
164235
self._compiled_model = compiled_model
165236
self._save_warmup = save_warmup
166237
self._return_raw_trace = return_raw_trace
167-
total_draws = (self._draws + self._tune) * self._chains
168-
self._progress = fastprogress.progress_bar(
169-
range(total_draws),
170-
total=total_draws,
171-
display=progress_bar,
172-
)
173-
# fastprogress seems to reset the progress bar
174-
# if we create a new iterator, but we don't want
175-
# this for multiple calls to wait.
176-
self._bar = iter(self._progress)
177-
178-
self._exit_event = Event()
179-
self._pause_event = Event()
180-
self._continue = Condition()
181-
182-
self._finished_draws = 0
183-
184-
next(self._bar)
185-
186-
def progress_callback(info):
187-
if info.draw == self._tune - 1:
188-
self._chains_tuning -= 1
189-
if info.draw == self._tune + self._draws - 1:
190-
self._chains_finished += 1
191-
if info.is_diverging and info.draw > self._tune:
192-
self._num_divs += 1
193-
if self._chains_tuning > 0:
194-
count = self._chains_tuning
195-
divs = self._num_divs
196-
self._progress.comment = (
197-
f" Chains in warmup: {count}, Divergences: {divs}"
198-
)
199-
else:
200-
count = self._chains - self._chains_finished
201-
divs = self._num_divs
202-
self._progress.comment = (
203-
f" Sampling chains: {count}, Divergences: {divs}"
204-
)
205-
try:
206-
next(self._bar)
207-
except StopIteration:
208-
pass
209-
self._finished_draws += 1
238+
239+
total_draws = settings.num_draws + settings.num_tune
210240

211241
if progress_bar:
212-
callback = progress_callback
242+
self._progress = _DetailedProgress(total_draws, settings.num_chains)
243+
callback = self._progress.callback
213244
else:
245+
self._progress = None
214246
callback = None
215247

216248
self._sampler = compiled_model._make_sampler(
217249
settings,
218250
init_mean,
219-
chains,
220251
cores,
221-
seed,
222252
callback=callback,
223253
)
224254

@@ -233,12 +263,10 @@ def wait(self, *, timeout=None):
233263
This resumes the sampler in case it had been paused.
234264
"""
235265
self._sampler.wait(timeout)
236-
self._sampler.finalize()
237-
return self._extract()
238-
239-
def _extract(self):
240266
results = self._sampler.extract_results()
267+
return self._extract(results)
241268

269+
def _extract(self, results):
242270
dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()}
243271
dims["mass_matrix_inv"] = ["unconstrained_parameter"]
244272
dims["gradient"] = ["unconstrained_parameter"]
@@ -253,7 +281,7 @@ def _extract(self):
253281
else:
254282
return _trace_to_arviz(
255283
results,
256-
self._tune,
284+
self._settings.num_tune,
257285
self._compiled_model.shapes,
258286
dims=dims,
259287
coords={
@@ -263,6 +291,11 @@ def _extract(self):
263291
save_warmup=self._save_warmup,
264292
)
265293

294+
def inspect(self):
295+
"""Get a copy of the current state of the trace"""
296+
results = self._sampler.inspect()
297+
return self._extract(results)
298+
266299
def pause(self):
267300
"""Pause the sampler."""
268301
self._sampler.pause()
@@ -278,7 +311,8 @@ def is_finished(self):
278311
def abort(self):
279312
"""Abort sampling and return the trace produced so far."""
280313
self._sampler.abort()
281-
return self._extract()
314+
results = self._sampler.extract_results()
315+
return self._extract(results)
282316

283317
def cancel(self):
284318
"""Abort sampling and discard progress."""
@@ -402,9 +436,10 @@ def sample(
402436
trace : arviz.InferenceData
403437
An ArviZ ``InferenceData`` object that contains the samples.
404438
"""
405-
settings = _lib.PySamplerArgs()
439+
settings = _lib.PyDiagGradNutsSettings(seed)
406440
settings.num_tune = tune
407441
settings.num_draws = draws
442+
settings.num_chains = chains
408443

409444
for name, val in kwargs.items():
410445
setattr(settings, name, val)
@@ -424,11 +459,7 @@ def sample(
424459
compiled_model,
425460
settings,
426461
init_mean,
427-
chains,
428462
cores,
429-
seed,
430-
draws,
431-
tune,
432463
progress_bar=progress_bar,
433464
save_warmup=save_warmup,
434465
return_raw_trace=return_raw_trace,

src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
mod pymc;
2-
mod sampler;
32
mod stan;
43
mod wrapper;
54

0 commit comments

Comments
 (0)