|
76 | 76 | _log = logging.getLogger(__name__)
|
77 | 77 |
|
78 | 78 |
|
| 79 | +def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> set: |
| 80 | + """Get the set of coords that have remained constant between the trace and model""" |
| 81 | + constant_coords = set() |
| 82 | + for dim, coord in trace_coords.items(): |
| 83 | + current_coord = model.coords.get(dim, None) |
| 84 | + current_length = model.dim_lengths.get(dim, None) |
| 85 | + if hasattr(current_length, "get_value"): |
| 86 | + current_length = current_length.get_value() |
| 87 | + elif hasattr(current_length, "data"): |
| 88 | + current_length = current_length.data |
| 89 | + if ( |
| 90 | + current_coord is not None |
| 91 | + and len(coord) == len(current_coord) |
| 92 | + and np.all(coord == current_coord) |
| 93 | + ) or ( |
| 94 | + # Coord was defined without values (only length) |
| 95 | + current_coord is None and len(coord) == current_length |
| 96 | + ): |
| 97 | + constant_coords.add(dim) |
| 98 | + return constant_coords |
| 99 | + |
| 100 | + |
79 | 101 | def get_vars_in_point_list(trace, model):
|
80 | 102 | """Get the list of Variable instances in the model that have values stored in the trace."""
|
81 | 103 | if not isinstance(trace, MultiTrace):
|
@@ -792,22 +814,7 @@ def sample_posterior_predictive(
|
792 | 814 | stacklevel=2,
|
793 | 815 | )
|
794 | 816 |
|
795 |
| - constant_coords = set() |
796 |
| - for dim, coord in trace_coords.items(): |
797 |
| - current_coord = model.coords.get(dim, None) |
798 |
| - current_length = model.dim_lengths.get(dim, None) |
799 |
| - if hasattr(current_length, 'eval'): |
800 |
| - current_length = current_length.eval() |
801 |
| - if ( |
802 |
| - current_coord is not None |
803 |
| - and len(coord) == len(current_coord) |
804 |
| - and np.all(coord == current_coord) |
805 |
| - ) or ( |
806 |
| - # Coord was defined without values (only length) |
807 |
| - current_coord is None |
808 |
| - and len(coord) == current_length |
809 |
| - ): |
810 |
| - constant_coords.add(dim) |
| 817 | + constant_coords = get_constant_coords(trace_coords, model) |
811 | 818 |
|
812 | 819 | if var_names is not None:
|
813 | 820 | vars_ = [model[x] for x in var_names]
|
|
0 commit comments