|
12 | 12 | from pytensor.compile.sharedvalue import SharedVariable, shared
|
13 | 13 | from pytensor.configdefaults import config
|
14 | 14 | from pytensor.graph import RewriteDatabaseQuery
|
15 |
| -from pytensor.graph.basic import Apply |
| 15 | +from pytensor.graph.basic import Apply, Constant |
16 | 16 | from pytensor.graph.fg import FunctionGraph
|
17 | 17 | from pytensor.graph.op import Op
|
18 | 18 | from pytensor.ifelse import ifelse
|
|
38 | 38 | py_mode = Mode(linker="py", optimizer=None)
|
39 | 39 |
|
40 | 40 |
|
| 41 | +def set_test_value(x, v): |
| 42 | + x.tag.test_value = v |
| 43 | + return x |
| 44 | + |
| 45 | + |
41 | 46 | def compare_pytorch_and_py(
|
42 | 47 | fgraph: FunctionGraph,
|
43 | 48 | test_inputs: Iterable,
|
@@ -471,3 +476,66 @@ def test_ScalarLoop_Elemwise_multi_carries():
|
471 | 476 | compare_pytorch_and_py(
|
472 | 477 | f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
|
473 | 478 | )
|
| 479 | + |
| 480 | + |
| 481 | +rng = np.random.default_rng(42849) |
| 482 | + |
| 483 | + |
| 484 | +@pytest.mark.parametrize( |
| 485 | + "n_splits, axis, values, sizes", |
| 486 | + [ |
| 487 | + ( |
| 488 | + 0, |
| 489 | + 0, |
| 490 | + set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)), |
| 491 | + set_test_value(pt.vector(dtype="int64"), []), |
| 492 | + ), |
| 493 | + ( |
| 494 | + 5, |
| 495 | + 0, |
| 496 | + set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)), |
| 497 | + set_test_value( |
| 498 | + pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5) |
| 499 | + ), |
| 500 | + ), |
| 501 | + ( |
| 502 | + 5, |
| 503 | + 0, |
| 504 | + set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)), |
| 505 | + set_test_value( |
| 506 | + pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5) |
| 507 | + ), |
| 508 | + ), |
| 509 | + ( |
| 510 | + 5, |
| 511 | + -1, |
| 512 | + set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), |
| 513 | + set_test_value( |
| 514 | + pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5) |
| 515 | + ), |
| 516 | + ), |
| 517 | + ( |
| 518 | + 5, |
| 519 | + -2, |
| 520 | + set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), |
| 521 | + set_test_value( |
| 522 | + pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5) |
| 523 | + ), |
| 524 | + ), |
| 525 | + ], |
| 526 | +) |
| 527 | +def test_Split(n_splits, axis, values, sizes): |
| 528 | + g = pt.split(values, sizes, n_splits, axis=axis) |
| 529 | + assert len(g) == n_splits |
| 530 | + if n_splits == 0: |
| 531 | + return |
| 532 | + g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g) |
| 533 | + |
| 534 | + compare_pytorch_and_py( |
| 535 | + g_fg, |
| 536 | + [ |
| 537 | + i.tag.test_value |
| 538 | + for i in g_fg.inputs |
| 539 | + if not isinstance(i, SharedVariable | Constant) |
| 540 | + ], |
| 541 | + ) |
0 commit comments