Skip to content

Commit 9e8db9d

Browse files
authored
Use CUDA 11.0 compatible JAX version. (#1124)
- Added additional smoke tests for jax and flax. http://215555626
1 parent f0ed644 commit 9e8db9d

File tree

3 files changed: +45 additions, -2 deletions (lines changed)

Dockerfile.tmpl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
114114

115115
# Install JAX
116116
{{ if eq .Accelerator "gpu" }}
117-
RUN pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
117+
RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
118118
/tmp/clean-layer.sh
119119
{{ else }}
120120
RUN pip install jax[cpu] && \

tests/test_flax.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,52 @@
11
import unittest
22

3+
import jax
34
import jax.numpy as jnp
45
import numpy as np
6+
import optax
57

68
from flax import linen as nn
9+
from flax.training import train_state
710

811

912
class TestFlax(unittest.TestCase):
1013

11-
def test_bla(self):
14+
def test_pooling(self):
1215
x = jnp.full((1, 3, 3, 1), 2.)
1316
mul_reduce = lambda x, y: x * y
1417
y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
1518
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))
19+
20+
def test_cnn(self):
21+
class CNN(nn.Module):
22+
@nn.compact
23+
def __call__(self, x):
24+
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
25+
x = nn.relu(x)
26+
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
27+
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
28+
x = nn.relu(x)
29+
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
30+
x = x.reshape((x.shape[0], -1))
31+
x = nn.Dense(features=256)(x)
32+
x = nn.relu(x)
33+
x = nn.Dense(features=120)(x)
34+
x = nn.log_softmax(x)
35+
return x
36+
37+
def create_train_state(rng, learning_rate, momentum):
38+
cnn = CNN()
39+
params = cnn.init(rng, jnp.ones([1, 224, 224, 3]))['params']
40+
tx = optax.sgd(learning_rate, momentum)
41+
return train_state.TrainState.create(
42+
apply_fn=cnn.apply, params=params, tx=tx)
43+
44+
rng = jax.random.PRNGKey(0)
45+
rng, init_rng = jax.random.split(rng)
46+
47+
learning_rate = 2e-5
48+
momentum = 0.9
49+
state = create_train_state(init_rng, learning_rate, momentum)
50+
self.assertEqual(0, state.step)
51+
52+

tests/test_jax.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
11
import unittest
2+
3+
import os
24
import time
35

46
import jax.numpy as np
@@ -16,3 +18,7 @@ def test_grad(self):
1618
grad_tanh = grad(self.tanh)
1719
ag = grad_tanh(1.0)
1820
self.assertEqual(0.4199743, ag)
21+
22+
def test_backend(self):
23+
expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu'
24+

0 commit comments

Comments
 (0)