Skip to content

Commit 95db980

Browse files
ricardoV94jessegrabowski
authored andcommitted
Avoid buggy version of JAX in CI
Related to jax-ml/jax#22751
1 parent f9bf218 commit 95db980

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ jobs:
156156
157157
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
158158
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
159-
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
159+
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "jax!=0.4.31" "jaxlib!=0.4.31" numpyro && pip install tensorflow-probability; fi
160160
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
161161
pip install pytest-sphinx
162162

0 commit comments

Comments
 (0)