Skip to content

Commit 26ba673

Browse files
committed
Implement JAX dispatch for Argsort
1 parent 3dd1f80 commit 26ba673

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pytensor/link/jax/dispatch/sort.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from jax import numpy as jnp
22

33
from pytensor.link.jax.dispatch import jax_funcify
4-
from pytensor.tensor.sort import SortOp
4+
from pytensor.tensor.sort import ArgSortOp, SortOp
55

66

77
@jax_funcify.register(SortOp)
@@ -12,3 +12,13 @@ def sort(arr, axis):
1212
return jnp.sort(arr, axis=axis, stable=stable)
1313

1414
return sort
15+
16+
17+
@jax_funcify.register(ArgSortOp)
18+
def jax_funcify_ArgSort(op, **kwargs):
19+
stable = op.kind == "stable"
20+
21+
def argsort(arr, axis):
22+
return jnp.argsort(arr, axis=axis, stable=stable)
23+
24+
return argsort

tests/link/jax/test_sort.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33

44
from pytensor.graph import FunctionGraph
55
from pytensor.tensor import matrix
6-
from pytensor.tensor.sort import sort
6+
from pytensor.tensor.sort import argsort, sort
77
from tests.link.jax.test_basic import compare_jax_and_py
88

99

1010
@pytest.mark.parametrize("axis", [None, -1])
11-
def test_sort(axis):
11+
@pytest.mark.parametrize("func", (sort, argsort))
12+
def test_sort(func, axis):
1213
x = matrix("x", shape=(2, 2), dtype="float64")
13-
out = sort(x, axis=axis)
14+
out = func(x, axis=axis)
1415
fgraph = FunctionGraph([x], [out])
1516
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
1617
compare_jax_and_py(fgraph, [arr])

0 commit comments

Comments
 (0)