Skip to content

Commit e6bf0e3

Browse files
Additional fixes requested by ruff
1 parent 9aee194 commit e6bf0e3

File tree

16 files changed

+49
-33
lines changed

16 files changed

+49
-33
lines changed

pymc_experimental/__init__.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
# limitations under the License.
1414
import logging
1515

16+
from pymc_experimental import distributions, gp, statespace, utils
17+
from pymc_experimental.inference.fit import fit
18+
from pymc_experimental.model.marginal_model import MarginalModel
19+
from pymc_experimental.model.model_api import as_model
1620
from pymc_experimental.version import __version__
1721

1822
_log = logging.getLogger("pmx")
@@ -23,7 +27,14 @@
2327
handler = logging.StreamHandler()
2428
_log.addHandler(handler)
2529

26-
from pymc_experimental import distributions, gp, statespace, utils
27-
from pymc_experimental.inference.fit import fit
28-
from pymc_experimental.model.marginal_model import MarginalModel
29-
from pymc_experimental.model.model_api import as_model
30+
31+
__all__ = [
32+
"distributions",
33+
"gp",
34+
"statespace",
35+
"utils",
36+
"fit",
37+
"MarginalModel",
38+
"as_model",
39+
"__version__",
40+
]
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
2+
3+
__all__ = ["R2D2M2CP"]

pymc_experimental/gp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414

1515

1616
from pymc_experimental.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
17+
18+
__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]

pymc_experimental/inference/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@
1414

1515

1616
from pymc_experimental.inference.fit import fit
17+
18+
__all__ = ["fit"]

pymc_experimental/inference/fit.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from importlib.util import find_spec
1415

1516

1617
def fit(method, **kwargs):
@@ -30,10 +31,8 @@ def fit(method, **kwargs):
3031
arviz.InferenceData
3132
"""
3233
if method == "pathfinder":
33-
try:
34-
import blackjax
35-
except ImportError as exc:
36-
raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc
34+
if find_spec("blackjax") is None:
35+
raise RuntimeError("Need BlackJAX to use `pathfinder`")
3736

3837
from pymc_experimental.inference.pathfinder import fit_pathfinder
3938

pymc_experimental/model/marginal_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pytensor import Mode, scan
2222
from pytensor.compile import SharedVariable
2323
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
24+
from pytensor.graph.basic import graph_inputs
2425
from pytensor.graph.replace import graph_replace, vectorize_graph
2526
from pytensor.scan import map as scan_map
2627
from pytensor.tensor import TensorType, TensorVariable
@@ -638,9 +639,6 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
638639
return True
639640

640641

641-
from pytensor.graph.basic import graph_inputs
642-
643-
644642
def collect_shared_vars(outputs, blockers):
645643
return [
646644
inp for inp in graph_inputs(outputs, blockers=blockers) if isinstance(inp, SharedVariable)

pymc_experimental/model/transforms/autoreparam.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,9 @@ def vip_reparametrize(
419419
lambda_names.append(lam.name)
420420
toposort_replace(fmodel, replacements, reverse=True)
421421
reparam_model = model_from_fgraph(fmodel)
422-
model_lambdas = {n: reparam_model[l] for l, n in zip(lambda_names, var_names)}
422+
model_lambdas = {
423+
var_name: reparam_model[lambda_name]
424+
for lambda_name, var_name in zip(lambda_names, var_names)
425+
}
423426
vip = VIP(model_lambdas)
424427
return reparam_model, vip

pymc_experimental/statespace/core/representation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import copy
22

3-
from typing import Union
4-
53
import numpy as np
64
import pytensor
75
import pytensor.tensor as pt
@@ -12,7 +10,7 @@
1210
)
1311

1412
floatX = pytensor.config.floatX
15-
KeyLike = Union[tuple[str | int, ...], str]
13+
KeyLike = tuple[str | int, ...] | str
1614

1715

1816
class PytensorRepresentation:

pymc_experimental/statespace/filters/distributions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ def update(self, node: Node):
9696

9797

9898
class _LinearGaussianStateSpace(Continuous):
99-
rv_op = LinearGaussianStateSpaceRV
100-
10199
def __new__(
102100
cls,
103101
name,
@@ -360,8 +358,6 @@ def update(self, node: Node):
360358

361359

362360
class SequenceMvNormal(Continuous):
363-
rv_op = KalmanFilterRV
364-
365361
def __new__(cls, *args, **kwargs):
366362
return super().__new__(cls, *args, **kwargs)
367363

pymc_experimental/statespace/filters/kalman_filter.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,14 @@ def handle_missing_values(
351351
self, y, Z, H
352352
) -> tuple[TensorVariable, TensorVariable, TensorVariable, float]:
353353
"""
354-
This function handles missing values in the observation data `y` and adjusts the design matrix `Z` and the
355-
observation noise covariance matrix `H` accordingly. Missing values are replaced with zeros to prevent
356-
propagating NaNs through the computation. The function also returns a binary flag tensor `all_nan_flag`,
357-
indicating if all values in the observation data are missing. This flag is used for numerical adjustments in
358-
the update method.
354+
Handle missing values in the observation data `y`
355+
356+
Adjusts the design matrix `Z` and the observation noise covariance matrix `H` by removing rows and/or columns
357+
associated with the data that is not observed at this iteration. Missing values are replaced with zeros to prevent
358+
propagating NaNs through the computation.
359+
360+
Return a binary flag tensor `all_nan_flag`,indicating if all values in the observation data are missing. This
361+
flag is used for numerical adjustments in the update method.
359362
360363
Parameters
361364
----------
@@ -660,7 +663,7 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag):
660663

661664

662665
class CholeskyFilter(BaseFilter):
663-
""" "
666+
"""
664667
Kalman filter with Cholesky factorization
665668
666669
Kalman filter implementation using a Cholesky factorization plus pt.solve_triangular to (attempt) to speed up
@@ -712,7 +715,7 @@ class SingleTimeseriesFilter(BaseFilter):
712715

713716
# TODO: This class should eventually be made irrelevant by pytensor re-writes.
714717
def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q):
715-
""" "
718+
"""
716719
Wrap the data in an `Assert` `Op` to ensure there is only one observed state.
717720
"""
718721
data = assert_data_is_1d(data, pt.eq(data.shape[1], 1))

pymc_experimental/statespace/models/structural.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from abc import ABC
55
from collections.abc import Sequence
6+
from itertools import pairwise
67
from typing import Any
78

89
import numpy as np
@@ -198,7 +199,7 @@ def make_symbolic_graph(self) -> None:
198199
def _state_slices_from_info(self):
199200
info = self._component_info.copy()
200201
comp_states = np.cumsum([0] + [info["k_states"] for info in info.values()])
201-
state_slices = [slice(i, j) for i, j in zip(comp_states[:-1], comp_states[1:])]
202+
state_slices = [slice(i, j) for i, j in pairwise(comp_states)]
202203

203204
return state_slices
204205

pymc_experimental/utils/linear_cg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def linear_cg(
6565
initial_guess=None,
6666
preconditioner=None,
6767
terminate_cg_by_size=False,
68-
use_eval_tolerange=False,
68+
use_eval_tolerance=False,
6969
):
7070
if initial_guess is None:
7171
initial_guess = np.zeros_like(rhs)

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import itertools
16+
import os
1617

1718
from codecs import open
1819
from os.path import dirname, join, realpath
@@ -64,8 +65,6 @@
6465
extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values())))
6566
extras_require["dev"] = dev_install_reqs
6667

67-
import os
68-
6968

7069
def read_version():
7170
here = os.path.abspath(os.path.dirname(__file__))

tests/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515

1616
from pymc_experimental.distributions import histogram_utils
1717
from pymc_experimental.distributions.histogram_utils import histogram_approximation
18+
19+
__all__ = ["histogram_utils", "histogram_approximation"]

tests/statespace/test_VARMAX.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from itertools import product
1+
from itertools import pairwise, product
22

33
import numpy as np
44
import pandas as pd
@@ -102,7 +102,7 @@ def test_VARMAX_update_matches_statsmodels(data, order, rng):
102102
sm_var = sm.tsa.VARMAX(data, order=(p, q))
103103

104104
param_counts = [None, *np.cumsum(list(sm_var.parameters.values())).tolist()]
105-
param_slices = [slice(a, b) for a, b in zip(param_counts[:-1], param_counts[1:])]
105+
param_slices = [slice(a, b) for a, b in pairwise(param_counts)]
106106
param_lists = [trend, ar, ma, reg, state_cov, obs_cov] = [
107107
sm_var.param_names[idx] for idx in param_slices
108108
]

tests/test_model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_predict(fitted_model_instance):
225225
prediction_data = pd.DataFrame({"input": x_pred})
226226
pred = fitted_model_instance.predict(prediction_data["input"])
227227
# Perform elementwise comparison using numpy
228-
assert type(pred) == np.ndarray
228+
assert isinstance(pred, np.ndarray)
229229
assert len(pred) > 0
230230

231231

0 commit comments

Comments
 (0)