Skip to content

Commit 35f4706

Browse files
committed
Move JAX sort dispatch to its own module
1 parent 2307d87 commit 35f4706

File tree

6 files changed

+29
-18
lines changed

6 files changed

+29
-18
lines changed

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
import pytensor.link.jax.dispatch.scan
1515
import pytensor.link.jax.dispatch.sparse
1616
import pytensor.link.jax.dispatch.blockwise
17+
import pytensor.link.jax.dispatch.sort
1718

1819
# isort: on

pytensor/link/jax/dispatch/sort.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from jax import numpy as jnp
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.sort import SortOp
5+
6+
7+
@jax_funcify.register(SortOp)
8+
def jax_funcify_Sort(op, **kwargs):
9+
def sort(arr, axis):
10+
return jnp.sort(arr, axis=axis)
11+
12+
return sort

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424
from pytensor.tensor.shape import Shape_i
25-
from pytensor.tensor.sort import SortOp
2625

2726

2827
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
@@ -206,11 +205,3 @@ def tri(*args):
206205
return jnp.tri(*args, dtype=op.dtype)
207206

208207
return tri
209-
210-
211-
@jax_funcify.register(SortOp)
212-
def jax_funcify_Sort(op, **kwargs):
213-
def sort(arr, axis):
214-
return jnp.sort(arr, axis=axis)
215-
216-
return sort

tests/link/jax/__init__.py

Whitespace-only changes.

tests/link/jax/test_sort.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor.graph import FunctionGraph
5+
from pytensor.tensor import matrix
6+
from pytensor.tensor.sort import sort
7+
from tests.link.jax.test_basic import compare_jax_and_py
8+
9+
10+
@pytest.mark.parametrize("axis", [None, -1])
11+
def test_sort(axis):
12+
x = matrix("x", shape=(2, 2), dtype="float64")
13+
out = sort(x, axis=axis)
14+
fgraph = FunctionGraph([x], [out])
15+
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
16+
compare_jax_and_py(fgraph, [arr])

tests/link/jax/test_tensor_basic.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,6 @@ def test_tri():
218218
compare_jax_and_py(fgraph, [])
219219

220220

221-
@pytest.mark.parametrize("axis", [None, -1])
222-
def test_sort(axis):
223-
x = matrix("x", shape=(2, 2), dtype="float64")
224-
out = pytensor.tensor.sort(x, axis=axis)
225-
fgraph = FunctionGraph([x], [out])
226-
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
227-
compare_jax_and_py(fgraph, [arr])
228-
229-
230221
def test_tri_nonconcrete():
231222
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
232223

0 commit comments

Comments
 (0)