Skip to content

Commit dba71da

Browse files
update JAX GPU version (#21293)
* test JAX version * update jax version * update jax version * update tensorflow version * keep tensorflow 2.18.1 * add comment
1 parent e19a2ed commit dba71da

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ source venv/bin/activate
1313
python --version
1414
python3 --version
1515

16-
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
16+
# setting the LD_LIBRARY_PATH manually is causing segmentation fault
17+
#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
1718
# Check cuda
1819
nvidia-smi
1920
nvcc --version

requirements-jax-cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ torch==2.6.0
88

99
# Jax with cuda support.
1010
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
11-
jax[cuda12]==0.4.28
11+
jax[cuda12]==0.6.0
1212
flax
1313

1414
-r requirements-common.txt

0 commit comments

Comments
 (0)