Skip to content

Commit eb3ff29

Browse files
committed
Add split test
1 parent dbc95e4 commit eb3ff29

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.sharedvalue import SharedVariable, shared
1313
from pytensor.configdefaults import config
1414
from pytensor.graph import RewriteDatabaseQuery
15-
from pytensor.graph.basic import Apply
15+
from pytensor.graph.basic import Apply, Constant
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.op import Op
1818
from pytensor.ifelse import ifelse
@@ -38,6 +38,11 @@
3838
py_mode = Mode(linker="py", optimizer=None)
3939

4040

41+
def set_test_value(x, v):
42+
x.tag.test_value = v
43+
return x
44+
45+
4146
def compare_pytorch_and_py(
4247
fgraph: FunctionGraph,
4348
test_inputs: Iterable,
@@ -471,3 +476,66 @@ def test_ScalarLoop_Elemwise_multi_carries():
471476
compare_pytorch_and_py(
472477
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
473478
)
479+
480+
481+
rng = np.random.default_rng(42849)
482+
483+
484+
@pytest.mark.parametrize(
485+
"n_splits, axis, values, sizes",
486+
[
487+
(
488+
0,
489+
0,
490+
set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)),
491+
set_test_value(pt.vector(dtype="int64"), []),
492+
),
493+
(
494+
5,
495+
0,
496+
set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)),
497+
set_test_value(
498+
pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)
499+
),
500+
),
501+
(
502+
5,
503+
0,
504+
set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)),
505+
set_test_value(
506+
pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)
507+
),
508+
),
509+
(
510+
5,
511+
-1,
512+
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
513+
set_test_value(
514+
pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)
515+
),
516+
),
517+
(
518+
5,
519+
-2,
520+
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
521+
set_test_value(
522+
pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)
523+
),
524+
),
525+
],
526+
)
527+
def test_Split(n_splits, axis, values, sizes):
528+
g = pt.split(values, sizes, n_splits, axis=axis)
529+
assert len(g) == n_splits
530+
if n_splits == 0:
531+
return
532+
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g)
533+
534+
compare_pytorch_and_py(
535+
g_fg,
536+
[
537+
i.tag.test_value
538+
for i in g_fg.inputs
539+
if not isinstance(i, SharedVariable | Constant)
540+
],
541+
)

0 commit comments

Comments
 (0)