|
42 | 42 | RandomGeneratorSharedVariable,
|
43 | 43 | RandomStateSharedVariable,
|
44 | 44 | )
|
45 |
| -from pytensor.tensor.sharedvar import SharedVariable |
| 45 | +from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable |
| 46 | +from pytensor.tensor.variable import TensorConstant |
46 | 47 | from rich.console import Console
|
47 | 48 | from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
|
48 | 49 | from rich.theme import Theme
|
@@ -82,9 +83,9 @@ def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> se
|
82 | 83 | for dim, coord in trace_coords.items():
|
83 | 84 | current_coord = model.coords.get(dim, None)
|
84 | 85 | current_length = model.dim_lengths.get(dim, None)
|
85 |
| - if hasattr(current_length, "get_value"): |
| 86 | + if isinstance(current_length, TensorSharedVariable): |
86 | 87 | current_length = current_length.get_value()
|
87 |
| - elif hasattr(current_length, "data"): |
| 88 | + elif isinstance(current_length, TensorConstant): |
88 | 89 | current_length = current_length.data
|
89 | 90 | if (
|
90 | 91 | current_coord is not None
|
|
0 commit comments