Skip to content

Commit 8b3de66

Browse files
authored
Upgrade JAX CUDA
http://b/281861396
1 parent 2c778df commit 8b3de66

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Dockerfile.tmpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
149149

150150
# Install JAX
151151
{{ if eq .Accelerator "gpu" }}
152-
RUN pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
152+
RUN pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
153153
/tmp/clean-layer.sh
154154
{{ else }}
155155
RUN pip install jax[cpu] && \

0 commit comments

Comments
 (0)