Skip to content

Commit 507de2f

Browse files
committed
util.py related fixes
1 parent 870db74 commit 507de2f

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

pymc/tests/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def fn(a=UNSET):
154154
def test_dataset_to_point_list():
155155
ds = xarray.Dataset()
156156
ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
157-
pl = dataset_to_point_list(ds)
157+
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
158158
assert isinstance(pl, list)
159159
assert len(pl) == 6
160160
assert isinstance(pl[0], dict)
@@ -163,4 +163,4 @@ def test_dataset_to_point_list():
163163
# Check that non-str keys are caught
164164
ds[3] = xarray.DataArray([1, 2, 3])
165165
with pytest.raises(ValueError, match="must be str"):
166-
dataset_to_point_list(ds)
166+
dataset_to_point_list(ds, sample_dims=["chain", "draw"])

pymc/util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import functools
1616

17-
from typing import Dict, List, cast
17+
from typing import Any, Dict, List, Tuple, cast
1818

1919
import cloudpickle
2020
import numpy as np
@@ -230,7 +230,9 @@ def enhanced(*args, **kwargs):
230230
return enhanced
231231

232232

233-
def dataset_to_point_list(ds: xarray.Dataset, sample_dims: List) -> List[Dict[str, np.ndarray]]:
233+
def dataset_to_point_list(
234+
ds: xarray.Dataset, sample_dims: List
235+
) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
234236
# All keys of the dataset must be a str
235237
var_names = list(ds.keys())
236238
for vn in var_names:

0 commit comments

Comments
 (0)