File tree Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -425,7 +425,30 @@ def test_ScalarLoop_while():
425
425
np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
426
426
427
427
428
- def test_ScalarLoop_Elemwise ():
428
+ def test_ScalarLoop_Elemwise_single_carries ():
429
+ n_steps = int64 ("n_steps" )
430
+ x0 = float64 ("x0" )
431
+ x = x0 * 2
432
+ until = x >= 10
433
+
434
+ scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
435
+ op = Elemwise (scalarop )
436
+
437
+ n_steps = pt .scalar ("n_steps" , dtype = "int32" )
438
+ x0 = pt .vector ("x0" , dtype = "float32" )
439
+ state , done = op (n_steps , x0 )
440
+
441
+ f = FunctionGraph ([n_steps , x0 ], [state , done ])
442
+ args = [
443
+ np .array (10 ).astype ("int32" ),
444
+ np .arange (0 , 5 ).astype ("float32" ),
445
+ ]
446
+ compare_pytorch_and_py (
447
+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
448
+ )
449
+
450
+
451
+ def test_ScalarLoop_Elemwise_multi_carries ():
429
452
n_steps = int64 ("n_steps" )
430
453
x0 = float64 ("x0" )
431
454
x1 = float64 ("x1" )
You can’t perform that action at this time.
0 commit comments