File tree Expand file tree Collapse file tree 2 files changed +15
-4
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 2 files changed +15
-4
lines changed Original file line number Diff line number Diff line change 1
1
from jax import numpy as jnp
2
2
3
3
from pytensor .link .jax .dispatch import jax_funcify
4
- from pytensor .tensor .sort import SortOp
4
+ from pytensor .tensor .sort import ArgSortOp , SortOp
5
5
6
6
7
7
@jax_funcify .register (SortOp )
@@ -12,3 +12,13 @@ def sort(arr, axis):
12
12
return jnp .sort (arr , axis = axis , stable = stable )
13
13
14
14
return sort
15
+
16
+
17
+ @jax_funcify .register (ArgSortOp )
18
+ def jax_funcify_ArgSort (op , ** kwargs ):
19
+ stable = op .kind == "stable"
20
+
21
+ def argsort (arr , axis ):
22
+ return jnp .argsort (arr , axis = axis , stable = stable )
23
+
24
+ return argsort
Original file line number Diff line number Diff line change 3
3
4
4
from pytensor .graph import FunctionGraph
5
5
from pytensor .tensor import matrix
6
- from pytensor .tensor .sort import sort
6
+ from pytensor .tensor .sort import argsort , sort
7
7
from tests .link .jax .test_basic import compare_jax_and_py
8
8
9
9
10
10
@pytest .mark .parametrize ("axis" , [None , - 1 ])
11
- def test_sort (axis ):
11
+ @pytest .mark .parametrize ("func" , (sort , argsort ))
12
+ def test_sort (func , axis ):
12
13
x = matrix ("x" , shape = (2 , 2 ), dtype = "float64" )
13
- out = sort (x , axis = axis )
14
+ out = func (x , axis = axis )
14
15
fgraph = FunctionGraph ([x ], [out ])
15
16
arr = np .array ([[1.0 , 4.0 ], [5.0 , 2.0 ]])
16
17
compare_jax_and_py (fgraph , [arr ])
You can’t perform that action at this time.
0 commit comments