Skip to content

Commit f82a0ad

Browse files
authored
Fix GPU support for JAX (#1179)
* Fix GPU support for JAX Added also a test to prevent regression. http://b/239603020 * Rename var in test
1 parent 029dea1 commit f82a0ad

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

Dockerfile.tmpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
123123

124124
# Install JAX
125125
{{ if eq .Accelerator "gpu" }}
126-
RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
126+
RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
127127
/tmp/clean-layer.sh
128128
{{ else }}
129129
RUN pip install jax[cpu] && \

tests/test_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import time
55

6+
import jax
67
import jax.numpy as np
78

89
from common import gpu_test
@@ -21,4 +22,4 @@ def test_grad(self):
2122

2223
def test_backend(self):
2324
expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu'
24-
25+
self.assertEqual(expected_backend, jax.default_backend())

0 commit comments

Comments
 (0)