Skip to content

Commit 292dae9

Browse files
committed
Remove some 'refresh' and make sure progress goes to 100%
1 parent d245a01 commit 292dae9

File tree

4 files changed

+35
-12
lines changed

4 files changed

+35
-12
lines changed

pymc/sampling/forward.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from pytensor.tensor.sharedvar import SharedVariable
4646
from rich.console import Console
47+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4748
from rich.theme import Theme
4849

4950
import pymc as pm
@@ -828,11 +829,21 @@ def sample_posterior_predictive(
828829
# All model variables have a name, but mypy does not know this
829830
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
830831
ppc_trace_t = _DefaultTrace(samples)
832+
833+
progress = CustomProgress(
834+
"[progress.description]{task.description}",
835+
BarColumn(),
836+
"[progress.percentage]{task.percentage:>3.0f}%",
837+
TimeRemainingColumn(),
838+
TextColumn("/"),
839+
TimeElapsedColumn(),
840+
console=Console(theme=progressbar_theme),
841+
disable=not progressbar,
842+
)
843+
831844
try:
832-
with CustomProgress(
833-
console=Console(theme=progressbar_theme), disable=not progressbar
834-
) as progress:
835-
task = progress.add_task("Sampling ...", total=samples)
845+
with progress:
846+
task = progress.add_task("Sampling ...", completed=0, total=samples)
836847
for idx in np.arange(samples):
837848
if nchain > 1:
838849
# the trace object will either be a MultiTrace (and have _straces)...
@@ -854,6 +865,7 @@ def sample_posterior_predictive(
854865
ppc_trace_t.insert(k.name, v, idx)
855866

856867
progress.advance(task)
868+
progress.update(task, refresh=True, completed=samples)
857869

858870
except KeyboardInterrupt:
859871
pass

pymc/sampling/mcmc.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from arviz.data.base import make_attrs
3737
from pytensor.graph.basic import Variable
3838
from rich.console import Console
39+
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
3940
from rich.theme import Theme
4041
from threadpoolctl import threadpool_limits
4142
from typing_extensions import Protocol
@@ -1075,16 +1076,28 @@ def _sample(
10751076
)
10761077
_pbar_data = {"chain": chain, "divergences": 0}
10771078
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1078-
with CustomProgress(
1079-
console=Console(theme=progressbar_theme), disable=not progressbar
1080-
) as progress:
1079+
1080+
progress = CustomProgress(
1081+
"[progress.description]{task.description}",
1082+
BarColumn(),
1083+
"[progress.percentage]{task.percentage:>3.0f}%",
1084+
TimeRemainingColumn(),
1085+
TextColumn("/"),
1086+
TimeElapsedColumn(),
1087+
console=Console(theme=progressbar_theme),
1088+
disable=not progressbar,
1089+
)
1090+
1091+
with progress:
10811092
try:
1082-
task = progress.add_task(_desc.format(**_pbar_data), total=draws)
1093+
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
10831094
for it, diverging in enumerate(sampling_gen):
10841095
if it >= skip_first and diverging:
10851096
_pbar_data["divergences"] += 1
1086-
progress.update(task)
1087-
progress.update(task, refresh=True, advance=1, completed=True)
1097+
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
1098+
progress.update(
1099+
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
1100+
)
10881101
except KeyboardInterrupt:
10891102
pass
10901103

pymc/sampling/parallel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,6 @@ def __iter__(self):
473473
self._completed_draws += 1
474474
if not tuning and stats and stats[0].get("diverging"):
475475
self._divergences += 1
476-
477476
progress.update(
478477
task,
479478
completed=self._completed_draws,

pymc/sampling/population.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def _sample_population(
102102

103103
with CustomProgress(disable=not progressbar) as progress:
104104
task = progress.add_task("[red]Sampling...", total=draws)
105-
106105
for _ in sampling:
107106
progress.update(task)
108107

0 commit comments

Comments
 (0)