
Commit b28a3a7

Only run required rewrites in JAX and PyTorch tests

Several tests ended up not testing the backend Op implementations due to constant folding of inputs.

1 parent 171bb8a, commit b28a3a7
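Reviewer note on the failure mode being fixed, as a minimal sketch (illustrative only; the real cases are in the test diffs below). When a test graph is rooted in constants, the full rewrite pipeline can constant-fold it in Python before the backend linker ever runs, so the JAX/PyTorch Op implementations are never actually exercised:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.as_tensor(np.arange(6.0).reshape(3, 2))  # a constant, not a symbolic input
out = pt.cumsum(x, axis=0)

# Under the full "JAX" mode, constant folding may evaluate cumsum eagerly,
# leaving no CumOp node for the JAX dispatcher to translate.
f = pytensor.function([], out, mode="JAX")
pytensor.dprint(f)  # inspect whether any Op survived the rewrites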

10 files changed: +69, -76 lines changed

pytensor/link/pytorch/dispatch/shape.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ def reshape(x, shape):
 @pytorch_funcify.register(Shape)
 def pytorch_funcify_Shape(op, **kwargs):
     def shape(x):
-        return x.shape
+        return torch.tensor(x.shape)
 
     return shape
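Why wrap the result (a sketch of mine, not from the commit): torch's `.shape` is a `torch.Size`, which is a plain tuple subclass, while PyTensor's `Shape` Op advertises an integer-vector output that downstream code expects to be a tensor:

import torch

x = torch.zeros((3, 2))
assert isinstance(x.shape, tuple)  # torch.Size subclasses tuple
shape_tensor = torch.tensor(x.shape)  # what the dispatch now returns
assert shape_tensor.tolist() == [3, 2]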

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 7 additions & 2 deletions

@@ -34,8 +34,13 @@ def subtensor(x, *flattened_indices):
 
 @pytorch_funcify.register(MakeSlice)
 def pytorch_funcify_makeslice(op, **kwargs):
-    def makeslice(*x):
-        return slice(x)
+    def makeslice(start, stop, step):
+        # Torch does not like numpy integers in indexing slices
+        return slice(
+            None if start is None else int(start),
+            None if stop is None else int(stop),
+            None if step is None else int(step),
+        )
 
     return makeslice
 
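The old `makeslice(*x)` was simply buggy: `slice(x)` builds a slice whose stop is the whole argument tuple. The new body also normalizes numpy integers, which, per the commit's own comment, torch rejects inside indexing slices. A standalone mirror of the new behaviour (`make_py_slice` is a hypothetical name for illustration):

import numpy as np

def make_py_slice(start, stop, step):
    # None passes through; numpy integers become plain Python ints
    return slice(
        None if start is None else int(start),
        None if stop is None else int(stop),
        None if step is None else int(step),
    )

s = make_py_slice(np.int64(1), np.int64(4), None)
assert s == slice(1, 4, None) and type(s.start) is int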

tests/link/jax/test_basic.py

Lines changed: 6 additions & 4 deletions

@@ -6,13 +6,15 @@
 
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_mode
+from pytensor.compile.mode import JAX, Mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.configdefaults import config
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op, get_test_value
 from pytensor.ifelse import ifelse
+from pytensor.link.jax import JAXLinker
 from pytensor.raise_op import assert_op
 from pytensor.tensor.type import dscalar, matrices, scalar, vector
 
@@ -26,9 +28,9 @@ def set_pytensor_flags():
 jax = pytest.importorskip("jax")
 
 
-# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
-jax_mode = get_mode("JAX")
-py_mode = get_mode("FAST_COMPILE")
+optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
+jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)
+py_mode = Mode(linker="py", optimizer=None)
 
 
 def compare_jax_and_py(
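The intended effect of the new module-level modes, sketched under the commit message's stated assumption that the "jax"-tagged rewrites do not include constant folding (so constant-rooted graphs now reach the JAXLinker intact, and `py_mode` runs no rewrites at all):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import JAX, Mode
from pytensor.graph import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker

optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)

x = pt.as_tensor(np.arange(6.0).reshape(3, 2))
f = pytensor.function([], pt.cumsum(x, axis=0), mode=jax_mode)
pytensor.dprint(f)  # the CumOp node should survive, so the JAX dispatch runs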

tests/link/jax/test_einsum.py

Lines changed: 6 additions & 7 deletions

@@ -1,8 +1,9 @@
 import numpy as np
 import pytest
 
-import pytensor
 import pytensor.tensor as pt
+from pytensor.graph import FunctionGraph
+from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
@@ -19,9 +20,8 @@ def test_jax_einsum():
         pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
     )
     out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
-    f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
-
-    np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))
+    fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
+    compare_jax_and_py(fg, [x, y, z])
 
 
 @pytest.mark.xfail(raises=NotImplementedError)
@@ -33,6 +33,5 @@ def test_ellipsis_einsum():
     x_pt = pt.tensor("x", shape=x.shape)
     y_pt = pt.tensor("y", shape=y.shape)
     out = pt.einsum(subscripts, x_pt, y_pt)
-    f = pytensor.function([x_pt, y_pt], out, mode="JAX")
-
-    np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
+    fg = FunctionGraph([x_pt, y_pt], [out])
+    compare_jax_and_py(fg, [x, y])

tests/link/jax/test_extra_ops.py

Lines changed: 25 additions & 42 deletions

@@ -1,59 +1,52 @@
 import numpy as np
 import pytest
-from packaging.version import parse as version_parse
 
 import pytensor.tensor.basic as ptb
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
 from pytensor.tensor import extra_ops as pt_extra_ops
-from pytensor.tensor.type import matrix
+from pytensor.tensor.type import matrix, tensor
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
 
 
-def set_test_value(x, v):
-    x.tag.test_value = v
-    return x
-
-
 def test_extra_ops():
     a = matrix("a")
-    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
+    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
 
     out = pt_extra_ops.cumsum(a, axis=0)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     out = pt_extra_ops.cumprod(a, axis=1)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     out = pt_extra_ops.diff(a, n=2, axis=1)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     out = pt_extra_ops.repeat(a, (3, 3), axis=1)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     c = ptb.as_tensor(5)
-
     out = pt_extra_ops.fill_diagonal(a, c)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     with pytest.raises(NotImplementedError):
         out = pt_extra_ops.fill_diagonal_offset(a, c, c)
         fgraph = FunctionGraph([a], [out])
-        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+        compare_jax_and_py(fgraph, [a_test])
 
     with pytest.raises(NotImplementedError):
        out = pt_extra_ops.Unique(axis=1)(a)
        fgraph = FunctionGraph([a], [out])
-        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+        compare_jax_and_py(fgraph, [a_test])
 
     indices = np.arange(np.prod((3, 4)))
     out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
@@ -63,40 +56,30 @@ def test_extra_ops():
     )
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="JAX Numpy API does not support dynamic shapes",
-)
-def test_extra_ops_dynamic_shapes():
-    a = matrix("a")
-    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
-
-    # This function also cannot take symbolic input.
-    c = ptb.as_tensor(5)
+@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
+def test_bartlett_dynamic_shape():
+    c = tensor(shape=(), dtype=int)
     out = pt_extra_ops.bartlett(c)
     fgraph = FunctionGraph([], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [np.array(5)])
 
-    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
-    out = pt_extra_ops.ravel_multi_index(multi_index, (3, 4))
-    fgraph = FunctionGraph([], [out])
-    compare_jax_and_py(
-        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
-    )
 
-    # The inputs are "concrete", yet it still has problems?
-    out = pt_extra_ops.Unique()(
-        ptb.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
-    )
+@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
+def test_ravel_multi_index_dynamic_shape():
+    x_test, y_test = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
+
+    x = tensor(shape=(None,), dtype=int)
+    y = tensor(shape=(None,), dtype=int)
+    out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
     fgraph = FunctionGraph([], [out])
-    compare_jax_and_py(fgraph, [])
+    compare_jax_and_py(fgraph, [x_test, y_test])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
-def test_unique_nonconcrete():
+@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
+def test_unique_dynamic_shape():
     a = matrix("a")
-    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
+    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
 
     out = pt_extra_ops.Unique()(a)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
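Background for the three xfail markers (a general JAX constraint, not something in this diff): under `jax.jit` every array shape must be static at trace time, so Ops whose output shape depends on a runtime value, like `bartlett(n)` or `Unique`, cannot be lowered:

import jax
import jax.numpy as jnp

@jax.jit
def dynamic_arange(n):
    return jnp.arange(n)  # output length depends on the traced value n

try:
    dynamic_arange(5)
except jax.errors.ConcretizationTypeError:
    print("jit rejects shapes that depend on traced values")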

tests/link/jax/test_random.py

Lines changed: 1 addition & 1 deletion

@@ -705,7 +705,7 @@ def test_multinomial():
     n = np.array([10, 40])
     p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
     g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = compile_random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode="JAX")
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
     np.testing.assert_allclose(

tests/link/jax/test_scan.py

Lines changed: 10 additions & 13 deletions

@@ -32,7 +32,7 @@ def test_scan_sit_sot(view):
     xs = xs[view]
     fg = FunctionGraph([x0], [xs])
     test_input_vals = [np.e]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
@@ -47,7 +47,7 @@ def test_scan_mit_sot(view):
     xs = xs[view]
     fg = FunctionGraph([x0], [xs])
     test_input_vals = [np.full((3,), np.e)]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
@@ -74,7 +74,7 @@ def step(xtm3, xtm1, ytm4, ytm2):
 
     fg = FunctionGraph([x0, y0], [xs, ys])
     test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
@@ -283,7 +283,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
         gamma_val,
         delta_val,
     ]
-    compare_jax_and_py(out_fg, test_input_vals)
+    compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
 
 
 def test_scan_mitsot_with_nonseq():
@@ -316,7 +316,7 @@ def input_step_fn(y_tm1, y_tm3, a):
     out_fg = FunctionGraph([a_pt], [y_scan_pt])
 
     test_input_vals = [np.array(10.0).astype(config.floatX)]
-    compare_jax_and_py(out_fg, test_input_vals)
+    compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("x0_func", [dvector, dmatrix])
@@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func):
         non_sequences=[A],
         outputs_info=[x0],
         n_steps=n_steps,
-        mode=get_mode("JAX"),
     )
 
     x0_val = (
@@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func):
 
     fg = FunctionGraph([x0, A], [xs])
     test_input_vals = [x0_val, A_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_nd_scan_sit_sot_with_seq():
@@ -362,15 +361,14 @@ def test_nd_scan_sit_sot_with_seq():
         non_sequences=[A],
         sequences=[x],
         n_steps=n_steps,
-        mode=get_mode("JAX"),
     )
 
     x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
     A_val = np.eye(k, dtype=config.floatX)
 
     fg = FunctionGraph([x, A], [xs])
     test_input_vals = [x_val, A_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_nd_scan_mit_sot():
@@ -384,7 +382,6 @@ def test_nd_scan_mit_sot():
         outputs_info=[{"initial": x0, "taps": [-3, -1]}],
         non_sequences=[A, B],
         n_steps=10,
-        mode=get_mode("JAX"),
     )
 
     fg = FunctionGraph([x0, A, B], [xs])
@@ -393,7 +390,7 @@ def test_nd_scan_mit_sot():
     B_val = np.eye(3, dtype=config.floatX)
 
     test_input_vals = [x0_val, A_val, B_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_nd_scan_sit_sot_with_carry():
@@ -417,7 +414,7 @@ def step(x, A):
     A_val = np.eye(3, dtype=config.floatX)
 
     test_input_vals = [x0_val, A_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_default_mode_excludes_incompatible_rewrites():
@@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites():
     B = matrix("B")
     out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
     fg = FunctionGraph([A, B], [out])
-    compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
+    compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
 
 
 def test_dynamic_sequence_length():
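Note: every scan test above opts back into the full registered mode via `jax_mode="JAX"`, presumably because Scan graphs need the complete rewrite pipeline to be transpilable, and these tests all feed symbolic inputs, so the constant-folding concern does not apply. The helper's new signature is not shown in this diff; what is certain is that a registered mode can be resolved from its name:

from pytensor.compile.mode import get_mode

# "JAX" names the registered full JAX mode, complete rewrite set included
full_jax_mode = get_mode("JAX")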

tests/link/jax/test_sparse.py

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
 
     dot_pt = op(x_pt, y_pt)
     fgraph = FunctionGraph(inputs, [dot_pt])
-    compare_jax_and_py(fgraph, test_values)
+    compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
 
 
 def test_sparse_dot_non_const_raises():

tests/link/jax/test_tensor_basic.py

Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def test_arange_of_shape():
     x = vector("x")
     out = ptb.arange(1, x.shape[-1], 2)
     fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [np.zeros((5,))])
+    compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
 
 
 def test_arange_nonconcrete():

tests/link/pytorch/test_basic.py

Lines changed: 11 additions & 4 deletions

@@ -7,13 +7,15 @@
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_mode
+from pytensor.compile.mode import PYTORCH, Mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.configdefaults import config
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.ifelse import ifelse
+from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import alloc, arange, as_tensor, empty, eye
 from pytensor.tensor.type import matrices, matrix, scalar, vector
@@ -22,8 +24,13 @@
 torch = pytest.importorskip("torch")
 
 
-pytorch_mode = get_mode("PYTORCH")
-py_mode = get_mode("FAST_COMPILE")
+optimizer = RewriteDatabaseQuery(
+    # While we don't have a PyTorch implementation of Blockwise
+    include=["local_useless_unbatched_blockwise"],
+    exclude=PYTORCH._optimizer.exclude,
+)
+pytorch_mode = Mode(linker=PytorchLinker(), optimizer=optimizer)
+py_mode = Mode(linker="py", optimizer=None)
 
 
 def compare_pytorch_and_py(
@@ -220,7 +227,7 @@ def test_alloc_and_empty():
     assert res.dtype == torch.float32
 
     v = vector("v", shape=(3,), dtype="float64")
-    out = alloc(v, (dim0, dim1, 3))
+    out = alloc(v, dim0, dim1, 3)
     compare_pytorch_and_py(
         FunctionGraph([v, dim1], [out]),
         [np.array([1, 2, 3]), np.array(7)],
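The `alloc` fix is unrelated to the mode changes: `pt.alloc(value, *shape)` takes each dimension as a separate argument, so the old `alloc(v, (dim0, dim1, 3))` passed a single tuple where scalars were expected. A standalone sketch of the corrected call:

import numpy as np
import pytensor
import pytensor.tensor as pt

v = pt.vector("v", shape=(3,), dtype="float64")
dim0 = pt.scalar("dim0", dtype="int64")
dim1 = pt.scalar("dim1", dtype="int64")
out = pt.alloc(v, dim0, dim1, 3)  # broadcast v up to shape (dim0, dim1, 3)

f = pytensor.function([v, dim0, dim1], out)
assert f(np.array([1.0, 2.0, 3.0]), 2, 7).shape == (2, 7, 3)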
