Skip to content

Commit c3e8210

Browse files
Disallow "__sample__" as a dimension name
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent ad9b919 commit c3e8210

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

pymc3/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,17 +1000,18 @@ def add_coord(
10001000
----------
10011001
name : str
10021002
Name of the dimension.
1003-
Forbidden: {"chain", "draw"}
1003+
Forbidden: {"chain", "draw", "__sample__"}
10041004
values : optional, array-like
10051005
Coordinate values or ``None`` (for auto-numbering).
10061006
If ``None`` is passed, a ``length`` must be specified.
10071007
length : optional, scalar
10081008
A symbolic scalar of the dimensions length.
10091009
Defaults to ``aesara.shared(len(values))``.
10101010
"""
1011-
if name in {"draw", "chain"}:
1011+
if name in {"draw", "chain", "__sample__"}:
10121012
raise ValueError(
1013-
"Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs."
1013+
"Dimensions can not be named `draw`, `chain` or `__sample__`, "
1014+
"as those are reserved for use in `InferenceData`."
10141015
)
10151016
if values is None and length is None:
10161017
raise ValueError(
@@ -1022,7 +1023,7 @@ def add_coord(
10221023
)
10231024
if name in self.coords:
10241025
if not values.equals(self.coords[name]):
1025-
raise ValueError("Duplicate and incompatiple coordinate: %s." % name)
1026+
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
10261027
else:
10271028
self._coords[name] = values
10281029
self._dim_lengths[name] = length or aesara.shared(len(values))

pymc3/tests/test_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_return_inferencedata(self, monkeypatch):
214214
return_inferencedata=True,
215215
discard_tuned_samples=True,
216216
idata_kwargs={"prior": prior},
217-
random_seed=-1
217+
random_seed=-1,
218218
)
219219
assert "prior" in result
220220
assert isinstance(result, InferenceData)

0 commit comments

Comments
 (0)