Skip to content

Commit fa42e7e

Browse files
committed
use np.shape
1 parent 48f7c55 commit fa42e7e

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

pymc3/sampling.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Functions for MCMC sampling."""
1616

17-
from typing import Dict, List, Optional, cast, Union, Any, Tuple
17+
from typing import Dict, List, Optional, cast, Union, Any
1818
from typing import Iterable as TIterable
1919
from collections.abc import Iterable
2020
from collections import defaultdict
@@ -1573,10 +1573,7 @@ def insert(self, k: str, v, idx: int):
15731573
ids: int
15741574
The index of the sample we are inserting into the trace.
15751575
"""
1576-
if hasattr(v, "shape"):
1577-
value_shape: Tuple[int, ...] = tuple(v.shape)
1578-
else:
1579-
value_shape = ()
1576+
value_shape = np.shape(v)
15801577

15811578
# initialize if necessary
15821579
if k not in self.trace_dict:

pymc3/tests/test_sampling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def test_return_inferencedata(self, monkeypatch):
213213
monkeypatch.setattr("pymc3.__version__", "3.10")
214214
with pytest.warns(FutureWarning, match="pass return_inferencedata"):
215215
result = pm.sample(**kwargs)
216+
pass
216217

217218
@pytest.mark.parametrize("cores", [1, 2])
218219
def test_sampler_stat_tune(self, cores):

0 commit comments

Comments
 (0)