@@ -835,8 +835,11 @@ class TestEulerMaruyama:
835
835
@pytest .mark .parametrize ("batched_param" , [1 , 2 ])
836
836
@pytest .mark .parametrize ("explicit_shape" , (True , False ))
837
837
def test_batched_size (self , explicit_shape , batched_param ):
838
+ RANDOM_SEED = 42
839
+ numpy_rng = np .random .default_rng (RANDOM_SEED )
840
+
838
841
steps , batch_size = 100 , 5
839
- param_val = np .square (np . random . randn (batch_size ))
842
+ param_val = np .square (numpy_rng . standard_normal (batch_size ))
840
843
if explicit_shape :
841
844
kwargs = {"shape" : (batch_size , steps )}
842
845
else :
@@ -853,9 +856,9 @@ def sde_fn(x, k, d, s):
853
856
"y" , dt = 0.02 , sde_fn = sde_fn , sde_pars = sde_pars , init_dist = init_dist , ** kwargs
854
857
)
855
858
856
- y_eval = draw (y , draws = 2 )
859
+ y_eval = draw (y , draws = 2 , random_seed = RANDOM_SEED )
857
860
assert y_eval [0 ].shape == (batch_size , steps )
858
- assert not np .any (np .isclose (y_eval [0 ], y_eval [1 ]))
861
+ assert np .any (~ np .isclose (y_eval [0 ], y_eval [1 ]))
859
862
860
863
if explicit_shape :
861
864
kwargs ["shape" ] = steps
@@ -873,7 +876,7 @@ def sde_fn(x, k, d, s):
873
876
** kwargs ,
874
877
)
875
878
876
- t0_init = t0 .initial_point ()
879
+ t0_init = t0 .initial_point (seed = RANDOM_SEED )
877
880
t1_init = {f"y_{ i } " : t0_init ["y" ][i ] for i in range (batch_size )}
878
881
np .testing .assert_allclose (
879
882
t0 .compile_logp ()(t0_init ),
0 commit comments