1
1
from collections .abc import Callable , Iterable
2
2
from functools import partial
3
- from itertools import repeat , starmap
4
- from unittest .mock import MagicMock , call , patch
5
3
6
4
import numpy as np
7
5
import pytest
@@ -408,25 +406,11 @@ def test_ScalarLoop_while():
408
406
for res , expected in zip (
409
407
[fn (n_steps = 20 , x0 = 0 ), fn (n_steps = 20 , x0 = 1 ), fn (n_steps = 5 , x0 = 1 )],
410
408
[[10 , True ], [10 , True ], [6 , False ]],
409
+ strict = True ,
411
410
):
412
411
np .testing .assert_allclose (res [0 ], np .array (expected [0 ]))
413
412
np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
414
413
415
- def test_pytorch_OpFromGraph ():
416
- x , y , z = matrices ("xyz" )
417
- ofg_1 = OpFromGraph ([x , y ], [x + y ])
418
- ofg_2 = OpFromGraph ([x , y ], [x * y , x - y ])
419
-
420
- o1 , o2 = ofg_2 (y , z )
421
- out = ofg_1 (x , o1 ) + o2
422
-
423
- xv = np .ones ((2 , 2 ), dtype = config .floatX )
424
- yv = np .ones ((2 , 2 ), dtype = config .floatX ) * 3
425
- zv = np .ones ((2 , 2 ), dtype = config .floatX ) * 5
426
-
427
- f = FunctionGraph ([x , y , z ], [out ])
428
- compare_pytorch_and_py (f , [xv , yv , zv ])
429
-
430
414
431
415
def test_ScalarLoop_Elemwise ():
432
416
n_steps = int64 ("n_steps" )
@@ -444,26 +428,3 @@ def test_ScalarLoop_Elemwise():
444
428
f = FunctionGraph ([n_steps , x0 ], [state , done ])
445
429
args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
446
430
compare_pytorch_and_py (f , args )
447
-
448
-
449
- torch_elemwise = pytest .importorskip ("pytensor.link.pytorch.dispatch.elemwise" )
450
-
451
-
452
- @pytest .mark .parametrize ("input_shapes" , [[(5 , 1 , 1 , 8 ), (3 , 1 , 1 ), (8 ,)]])
453
- @patch ("pytensor.link.pytorch.dispatch.elemwise.Elemwise" )
454
- def test_ScalarLoop_Elemwise_iteration_logic (_ , input_shapes ):
455
- args = [torch .ones (* s ) for s in input_shapes [:- 1 ]] + [
456
- torch .zeros (* input_shapes [- 1 ])
457
- ]
458
- mock_inner_func = MagicMock ()
459
- ret_value = torch .rand (2 , 2 ).unbind (0 )
460
- mock_inner_func .f .return_value = ret_value
461
- elemwise_fn = torch_elemwise .elemwise_scalar_loop (mock_inner_func .f , None , None )
462
- result = elemwise_fn (* args )
463
- for actual , expected in zip (ret_value , result ):
464
- assert torch .all (torch .eq (* torch .broadcast_tensors (actual , expected )))
465
- np .testing .assert_equal (mock_inner_func .f .call_count , len (result [0 ]))
466
-
467
- expected_args = torch .FloatTensor ([1.0 ] * (len (input_shapes ) - 1 ) + [0.0 ]).unbind (0 )
468
- expected_calls = starmap (call , repeat (expected_args , mock_inner_func .f .call_count ))
469
- mock_inner_func .f .assert_has_calls (expected_calls )
0 commit comments