File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -123,7 +123,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
123
123
124
124
# Install JAX
125
125
{{ if eq .Accelerator "gpu" }}
126
- RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases .html && \
126
+ RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases .html && \
127
127
/tmp/clean-layer.sh
128
128
{{ else }}
129
129
RUN pip install jax[cpu] && \
Original file line number Diff line number Diff line change 3
3
import os
4
4
import time
5
5
6
+ import jax
6
7
import jax .numpy as np
7
8
8
9
from common import gpu_test
@@ -21,4 +22,4 @@ def test_grad(self):
21
22
22
23
def test_backend (self ):
23
24
expected_backend = 'cpu' if len (os .environ .get ('CUDA_VERSION' , '' )) == 0 else 'gpu'
24
-
25
+ self . assertEqual ( expected_backend , jax . default_backend ())
You can’t perform that action at this time.
0 commit comments