Skip to content

Commit 5008194

Browse files
Remove tests for random variable samples shape and size
Most of the random variable logic has been moved to aesara, as well as most of the relative tests. More details can be found on issue #4554
1 parent a312231 commit 5008194

File tree

1 file changed

+0
-111
lines changed

1 file changed

+0
-111
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,6 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
250250
default_shape = (1,)
251251

252252

253-
@pytest.mark.skip(reason="This test is covered by Aesara")
254-
class TestNormal(BaseTestCases.BaseTestCase):
255-
distribution = pm.Normal
256-
params = {"mu": 0.0, "tau": 1.0}
257-
258-
259253
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
260254
class TestTruncatedNormal(BaseTestCases.BaseTestCase):
261255
distribution = pm.TruncatedNormal
@@ -280,18 +274,6 @@ class TestSkewNormal(BaseTestCases.BaseTestCase):
280274
params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
281275

282276

283-
@pytest.mark.skip(reason="This test is covered by Aesara")
284-
class TestHalfNormal(BaseTestCases.BaseTestCase):
285-
distribution = pm.HalfNormal
286-
params = {"tau": 1.0}
287-
288-
289-
@pytest.mark.skip(reason="This test is covered by Aesara")
290-
class TestUniform(BaseTestCases.BaseTestCase):
291-
distribution = pm.Uniform
292-
params = {"lower": 0.0, "upper": 1.0}
293-
294-
295277
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
296278
class TestTriangular(BaseTestCases.BaseTestCase):
297279
distribution = pm.Triangular
@@ -315,12 +297,6 @@ class TestKumaraswamy(BaseTestCases.BaseTestCase):
315297
params = {"a": 1.0, "b": 1.0}
316298

317299

318-
@pytest.mark.skip(reason="This test is covered by Aesara")
319-
class TestExponential(BaseTestCases.BaseTestCase):
320-
distribution = pm.Exponential
321-
params = {"lam": 1.0}
322-
323-
324300
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
325301
class TestLaplace(BaseTestCases.BaseTestCase):
326302
distribution = pm.Laplace
@@ -351,30 +327,6 @@ class TestPareto(BaseTestCases.BaseTestCase):
351327
params = {"alpha": 0.5, "m": 1.0}
352328

353329

354-
@pytest.mark.skip(reason="This test is covered by Aesara")
355-
class TestCauchy(BaseTestCases.BaseTestCase):
356-
distribution = pm.Cauchy
357-
params = {"alpha": 1.0, "beta": 1.0}
358-
359-
360-
@pytest.mark.skip(reason="This test is covered by Aesara")
361-
class TestHalfCauchy(BaseTestCases.BaseTestCase):
362-
distribution = pm.HalfCauchy
363-
params = {"beta": 1.0}
364-
365-
366-
@pytest.mark.skip(reason="This test is covered by Aesara")
367-
class TestGamma(BaseTestCases.BaseTestCase):
368-
distribution = pm.Gamma
369-
params = {"alpha": 1.0, "beta": 1.0}
370-
371-
372-
@pytest.mark.skip(reason="This test is covered by Aesara")
373-
class TestInverseGamma(BaseTestCases.BaseTestCase):
374-
distribution = pm.InverseGamma
375-
params = {"alpha": 0.5, "beta": 0.5}
376-
377-
378330
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
379331
class TestChiSquared(BaseTestCases.BaseTestCase):
380332
distribution = pm.ChiSquared
@@ -417,42 +369,18 @@ class TestLogitNormal(BaseTestCases.BaseTestCase):
417369
params = {"mu": 0.0, "sigma": 1.0}
418370

419371

420-
@pytest.mark.skip(reason="This test is covered by Aesara")
421-
class TestBinomial(BaseTestCases.BaseTestCase):
422-
distribution = pm.Binomial
423-
params = {"n": 5, "p": 0.5}
424-
425-
426372
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
427373
class TestBetaBinomial(BaseTestCases.BaseTestCase):
428374
distribution = pm.BetaBinomial
429375
params = {"n": 5, "alpha": 1.0, "beta": 1.0}
430376

431377

432-
@pytest.mark.skip(reason="This test is covered by Aesara")
433-
class TestBernoulli(BaseTestCases.BaseTestCase):
434-
distribution = pm.Bernoulli
435-
params = {"p": 0.5}
436-
437-
438378
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
439379
class TestDiscreteWeibull(BaseTestCases.BaseTestCase):
440380
distribution = pm.DiscreteWeibull
441381
params = {"q": 0.25, "beta": 2.0}
442382

443383

444-
@pytest.mark.skip(reason="This test is covered by Aesara")
445-
class TestPoisson(BaseTestCases.BaseTestCase):
446-
distribution = pm.Poisson
447-
params = {"mu": 1.0}
448-
449-
450-
@pytest.mark.skip(reason="This test is covered by Aesara")
451-
class TestNegativeBinomial(BaseTestCases.BaseTestCase):
452-
distribution = pm.NegativeBinomial
453-
params = {"mu": 1.0, "alpha": 1.0}
454-
455-
456384
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
457385
class TestConstant(BaseTestCases.BaseTestCase):
458386
distribution = pm.Constant
@@ -501,45 +429,6 @@ class TestMoyal(BaseTestCases.BaseTestCase):
501429
params = {"mu": 0.0, "sigma": 1.0}
502430

503431

504-
@pytest.mark.skip(reason="This test is covered by Aesara")
505-
class TestCategorical(BaseTestCases.BaseTestCase):
506-
distribution = pm.Categorical
507-
params = {"p": np.ones(BaseTestCases.BaseTestCase.shape)}
508-
509-
def get_random_variable(
510-
self, shape, with_vector_params=False, **kwargs
511-
): # don't transform categories
512-
return super().get_random_variable(shape, with_vector_params=False, **kwargs)
513-
514-
def test_probability_vector_shape(self):
515-
"""Check that if a 2d array of probabilities are passed to categorical correct shape is returned"""
516-
p = np.ones((10, 5))
517-
assert pm.Categorical.dist(p=p).random().shape == (10,)
518-
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 10)
519-
p = np.ones((3, 7, 5))
520-
assert pm.Categorical.dist(p=p).random().shape == (3, 7)
521-
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 3, 7)
522-
523-
524-
@pytest.mark.skip(reason="This test is covered by Aesara")
525-
class TestDirichlet(SeededTest):
526-
@pytest.mark.parametrize(
527-
"shape, size",
528-
[
529-
((2), (1)),
530-
((2), (2)),
531-
((2, 2), (2, 100)),
532-
((3, 4), (3, 4)),
533-
((3, 4), (3, 4, 100)),
534-
((3, 4), (100)),
535-
((3, 4), (1)),
536-
],
537-
)
538-
def test_dirichlet_random_shape(self, shape, size):
539-
out_shape = to_tuple(size) + to_tuple(shape)
540-
assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
541-
542-
543432
class TestCorrectParametrizationMappingPymcToScipy(SeededTest):
544433
@staticmethod
545434
def get_inputs_from_apply_node_outputs(outputs):

0 commit comments

Comments
 (0)