Skip to content

Commit 694660c

Browse files
committed
type check instead of attr check
1 parent e7264e2 commit 694660c

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pymc/sampling/forward.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
RandomGeneratorSharedVariable,
4343
RandomStateSharedVariable,
4444
)
45-
from pytensor.tensor.sharedvar import SharedVariable
45+
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
46+
from pytensor.tensor.variable import TensorConstant
4647
from rich.console import Console
4748
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4849
from rich.theme import Theme
@@ -82,9 +83,9 @@ def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> se
8283
for dim, coord in trace_coords.items():
8384
current_coord = model.coords.get(dim, None)
8485
current_length = model.dim_lengths.get(dim, None)
85-
if hasattr(current_length, "get_value"):
86+
if isinstance(current_length, TensorSharedVariable):
8687
current_length = current_length.get_value()
87-
elif hasattr(current_length, "data"):
88+
elif isinstance(current_length, TensorConstant):
8889
current_length = current_length.data
8990
if (
9091
current_coord is not None

0 commit comments

Comments
 (0)