import unittest

+import jax
import jax.numpy as jnp
import numpy as np
+import optax

from flax import linen as nn
+from flax.training import train_state


class TestFlax(unittest.TestCase):

-    def test_bla(self):
+    def test_pooling(self):
        x = jnp.full((1, 3, 3, 1), 2.)
        mul_reduce = lambda x, y: x * y
        y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))
+
+    def test_cnn(self):
+        class CNN(nn.Module):
+            @nn.compact
+            def __call__(self, x):
+                x = nn.Conv(features=32, kernel_size=(3, 3))(x)
+                x = nn.relu(x)
+                x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+                x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+                x = nn.relu(x)
+                x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+                x = x.reshape((x.shape[0], -1))
+                x = nn.Dense(features=256)(x)
+                x = nn.relu(x)
+                x = nn.Dense(features=120)(x)
+                x = nn.log_softmax(x)
+                return x
+
+        def create_train_state(rng, learning_rate, momentum):
+            cnn = CNN()
+            params = cnn.init(rng, jnp.ones([1, 224, 224, 3]))['params']
+            tx = optax.sgd(learning_rate, momentum)
+            return train_state.TrainState.create(
+                apply_fn=cnn.apply, params=params, tx=tx)
+
+        rng = jax.random.PRNGKey(0)
+        rng, init_rng = jax.random.split(rng)
+
+        learning_rate = 2e-5
+        momentum = 0.9
+        state = create_train_state(init_rng, learning_rate, momentum)
+        self.assertEqual(0, state.step)
+
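
Side note on test_pooling: nn.pooling.pool is the generic windowed reduction that nn.max_pool and nn.avg_pool are built on. With init value 1. and a multiply reducer, each VALID 2x2 window over the all-twos input multiplies four 2s together, hence the expected 2. ** 4. A minimal sketch of the same generic pool reproducing max pooling (this snippet is illustrative, not part of the commit):

    x = jnp.arange(9.).reshape(1, 3, 3, 1)
    # init -inf with a binary max reducer over each window is max pooling.
    y = nn.pooling.pool(x, -jnp.inf, jax.lax.max, (2, 2), (1, 1), 'VALID')
    np.testing.assert_allclose(y, nn.max_pool(x, window_shape=(2, 2), strides=(1, 1)))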
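
The new test_cnn only asserts that the freshly created TrainState starts at step 0. A minimal sketch of how that state would typically be advanced, assuming the usual Flax train-step pattern; train_step, images, and labels are hypothetical and not part of the commit:

    @jax.jit
    def train_step(state, images, labels):
        # Assumed shapes: images (B, 224, 224, 3) floats, labels (B,) ints in [0, 120).
        def loss_fn(params):
            log_probs = state.apply_fn({'params': params}, images)
            one_hot = jax.nn.one_hot(labels, num_classes=120)
            # The model already applies log_softmax, so the loss is the mean
            # negative log-probability of the correct class.
            return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        return state.apply_gradients(grads=grads), loss

Each call returns a new state with state.step incremented by one; TrainState.create initializes it to 0, which is what the test checks.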