Skip to content

Commit 5aa9544

Browse files
committed
Make jax optional in testing suite
1 parent 7815e5b commit 5aa9544

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ lines-between-types = 1
7070
[tool.ruff.lint.per-file-ignores]
7171
'tests/*.py' = [
7272
'F841', # Unused variable warning for test files -- common in pymc model declarations
73-
'D106' # Missing docstring for public method -- unittest test subclasses don't need docstrings
73+
'D106', # Missing docstring for public method -- unittest test subclasses don't need docstrings
74+
'E402' # Import at top, not respected when pytest.importorskip is required
7475
]
7576
'tests/statespace/*.py' = [
7677
'F401', # Unused import warning for test files -- this check removes imports of fixtures

tests/test_blackjax_smc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import jax
1514
import numpy as np
1615
import pymc as pm
1716
import pytensor.tensor as pt
@@ -21,6 +20,9 @@
2120
from numpy import dtype
2221
from xarray.core.utils import Frozen
2322

23+
jax = pytest.importorskip("jax")
24+
pytest.importorskip("blackjax")
25+
2426
from pymc_experimental.inference.smc.sampling import (
2527
arviz_from_particles,
2628
blackjax_particles_from_pymc_population,

0 commit comments

Comments
 (0)