Skip to content

Fix CI #5683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2022
Merged

Fix CI #5683

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
exclude: ^requirements-dev\.txt$
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.941
rev: v0.942
hooks:
- id: mypy
name: Run static type checks
Expand Down Expand Up @@ -42,11 +42,11 @@ repos:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/PyCQA/pylint
rev: v2.12.2
rev: v2.13.2
hooks:
- id: pylint
args: [--rcfile=.pylintrc]
Expand Down
10 changes: 7 additions & 3 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
Expand Down Expand Up @@ -178,7 +179,10 @@ def __init__(
" one of trace, prior, posterior_predictive or predictions."
)

untyped_coords = {**self.model.coords, **(coords or {})}
# Make coord types more rigid
untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
if coords:
untyped_coords.update(coords)
self.coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
for cname, cvals in untyped_coords.items()
Expand Down Expand Up @@ -649,8 +653,8 @@ def predictions_to_inference_data(
)
if hasattr(idata_orig, "posterior"):
assert idata_orig is not None
converter.nchains = idata_orig.posterior.dims["chain"]
converter.ndraws = idata_orig.posterior.dims["draw"]
converter.nchains = idata_orig["posterior"].dims["chain"]
converter.ndraws = idata_orig["posterior"].dims["draw"]
else:
aelem = next(iter(predictions.values()))
converter.nchains, converter.ndraws = aelem.shape[:2]
Expand Down
8 changes: 4 additions & 4 deletions pymc/backends/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
self._add_warnings([warn])
return

if idata.posterior.sizes["chain"] == 1:
if idata["posterior"].sizes["chain"] == 1:
msg = (
"Only one chain was sampled, this makes it impossible to "
"run some convergence checks"
Expand All @@ -124,7 +124,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
self._add_warnings([warn])
return

elif idata.posterior.sizes["chain"] < 4:
elif idata["posterior"].sizes["chain"] < 4:
msg = (
"We recommend running at least 4 chains for robust computation of "
"convergence diagnostics"
Expand All @@ -140,7 +140,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
if is_transformed_name(rv_name):
rv_name2 = get_untransformed_name(rv_name)
rv_name = rv_name2 if rv_name2 in valid_name else rv_name
if rv_name in idata.posterior:
if rv_name in idata["posterior"]:
varnames.append(rv_name)

self._ess = ess = arviz.ess(idata, var_names=varnames)
Expand All @@ -158,7 +158,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
warnings.append(warn)

eff_min = min(val.min() for val in ess.values())
eff_per_chain = eff_min / idata.posterior.sizes["chain"]
eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
if eff_per_chain < 100:
msg = (
"The effective sample size per chain is smaller than 100 for some parameters. "
Expand Down
28 changes: 14 additions & 14 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ def sample(
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
f"took {mtrace.report.t_sampling:.0f} seconds."
)
mtrace.report._log_summary()

idata = None
if compute_convergence_checks or return_inferencedata:
Expand All @@ -622,19 +623,18 @@ def sample(
ikwargs.update(idata_kwargs)
idata = pm.to_inference_data(mtrace, **ikwargs)

if compute_convergence_checks:
if draws - tune < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.", stacklevel=2
)
else:
mtrace.report._run_convergence_checks(idata, model)
mtrace.report._log_summary()
if compute_convergence_checks:
if draws - tune < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
else:
mtrace.report._run_convergence_checks(idata, model)

if return_inferencedata:
return idata
else:
return mtrace
if return_inferencedata:
return idata
return mtrace


def _check_start_shape(model, start: PointType):
Expand Down Expand Up @@ -1621,7 +1621,7 @@ def sample_posterior_predictive(
_trace: Union[MultiTrace, PointList]
nchain: int
if isinstance(trace, InferenceData):
_trace = dataset_to_point_list(trace.posterior)
_trace = dataset_to_point_list(trace["posterior"])
nchain, len_trace = chains_and_samples(trace)
elif isinstance(trace, xarray.Dataset):
_trace = dataset_to_point_list(trace)
Expand Down Expand Up @@ -1704,7 +1704,7 @@ def sample_posterior_predictive(

if not vars_to_sample:
if return_inferencedata and not extend_inferencedata:
return None
return InferenceData()
elif return_inferencedata and extend_inferencedata:
return trace
return {}
Expand Down
2 changes: 1 addition & 1 deletion pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tupl
if isinstance(data, xarray.Dataset):
dataset = data
elif isinstance(data, arviz.InferenceData):
dataset = data.posterior
dataset = data["posterior"]
else:
raise ValueError(
"Argument must be xarray Dataset or arviz InferenceData. Got %s",
Expand Down