Skip to content

Commit ea7afba

Browse files
committed
Refactor Constant
1 parent 03e1df5 commit ea7afba

File tree

4 files changed

+50
-58
lines changed

4 files changed

+50
-58
lines changed

pymc3/distributions/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
Binomial,
6464
Categorical,
6565
Constant,
66-
ConstantDist,
6766
DiscreteUniform,
6867
DiscreteWeibull,
6968
Geometric,
@@ -138,7 +137,6 @@
138137
"Bernoulli",
139138
"Poisson",
140139
"NegativeBinomial",
141-
"ConstantDist",
142140
"Constant",
143141
"ZeroInflatedPoisson",
144142
"ZeroInflatedNegativeBinomial",

pymc3/distributions/discrete.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import warnings
15-
1614
import aesara.tensor as at
1715
import numpy as np
1816

@@ -51,7 +49,6 @@
5149
"DiscreteWeibull",
5250
"Poisson",
5351
"NegativeBinomial",
54-
"ConstantDist",
5552
"Constant",
5653
"ZeroInflatedPoisson",
5754
"ZeroInflatedBinomial",
@@ -1164,6 +1161,23 @@ def logp(value, p):
11641161
)
11651162

11661163

1164+
class ConstantRV(RandomVariable):
1165+
name = "constant"
1166+
ndim_supp = 0
1167+
ndims_params = [0]
1168+
dtype = "floatX" # Should be treated as a discrete variable!
1169+
_print_name = ("Constant", "\\operatorname{Constant}")
1170+
1171+
@classmethod
1172+
def rng_fn(cls, rng, c, size=None):
1173+
if size is None:
1174+
return c.copy()
1175+
return np.full(size, c)
1176+
1177+
1178+
constant = ConstantRV()
1179+
1180+
11671181
class Constant(Discrete):
11681182
r"""
11691183
Constant log-likelihood.
@@ -1174,40 +1188,14 @@ class Constant(Discrete):
11741188
Constant parameter.
11751189
"""
11761190

1177-
def __init__(self, c, *args, **kwargs):
1178-
warnings.warn(
1179-
"Constant has been deprecated. We recommend using a Deterministic object instead.",
1180-
DeprecationWarning,
1181-
)
1182-
super().__init__(*args, **kwargs)
1183-
self.mean = self.median = self.mode = self.c = c = at.as_tensor_variable(c)
1184-
1185-
def random(self, point=None, size=None):
1186-
r"""
1187-
Draw random values from Constant distribution.
1188-
1189-
Parameters
1190-
----------
1191-
point: dict, optional
1192-
Dict of variable values on which random values are to be
1193-
conditioned (uses default point if not specified).
1194-
size: int, optional
1195-
Desired size of random sample (returns one sample if not
1196-
specified).
1191+
rv_op = constant
11971192

1198-
Returns
1199-
-------
1200-
array
1201-
"""
1202-
# c = draw_values([self.c], point=point, size=size)[0]
1203-
# dtype = np.array(c).dtype
1204-
#
1205-
# def _random(c, dtype=dtype, size=None):
1206-
# return np.full(size, fill_value=c, dtype=dtype)
1207-
#
1208-
# return generate_samples(_random, c=c, dist_shape=self.shape, size=size).astype(dtype)
1193+
@classmethod
1194+
def dist(cls, c, *args, **kwargs):
1195+
c = at.as_tensor_variable(floatX(c))
1196+
return super().dist([c], **kwargs)
12091197

1210-
def logp(self, value):
1198+
def logp(value, c):
12111199
r"""
12121200
Calculate log-probability of Constant distribution at specified value.
12131201
@@ -1221,11 +1209,10 @@ def logp(self, value):
12211209
-------
12221210
TensorVariable
12231211
"""
1224-
c = self.c
1225-
return bound(0, at.eq(value, c))
1226-
1227-
1228-
ConstantDist = Constant
1212+
return bound(
1213+
at.zeros_like(value),
1214+
at.eq(value, c),
1215+
)
12291216

12301217

12311218
class ZeroInflatedPoisson(Discrete):

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,7 +1614,6 @@ def test_bound_poisson(self):
16141614
x = NonZeroPoisson("x", mu=4)
16151615
assert np.isinf(logpt(x, 0).eval())
16161616

1617-
@pytest.mark.xfail(reason="Distribution not refactored yet")
16181617
def test_constantdist(self):
16191618
self.check_logp(Constant, I, {"c": I}, lambda value, c: np.log(c == value))
16201619

@@ -2820,13 +2819,17 @@ def test_issue_4499(self):
28202819
# Test for bug in Uniform and DiscreteUniform logp when setting check_bounds = False
28212820
# https://github.com/pymc-devs/pymc3/issues/4499
28222821
with pm.Model(check_bounds=False) as m:
2823-
x = pm.Uniform("x", 0, 2, shape=10, transform=None)
2822+
x = pm.Uniform("x", 0, 2, size=10, transform=None)
28242823
assert_almost_equal(m.logp({"x": np.ones(10)}), -np.log(2) * 10)
28252824

28262825
with pm.Model(check_bounds=False) as m:
28272826
x = pm.DiscreteUniform("x", 0, 1, size=10)
28282827
assert_almost_equal(m.logp({"x": np.ones(10)}), -np.log(2) * 10)
28292828

2829+
with pm.Model(check_bounds=False) as m:
2830+
x = pm.Constant("x", 1, size=10)
2831+
assert_almost_equal(m.logp({"x": np.ones(10)}), 0 * 10)
2832+
28302833

28312834
@pytest.mark.xfail(reason="DensityDist no longer supported")
28322835
def test_serialize_density_dist():

pymc3/tests/test_distributions_random.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from pymc3.tests.helpers import SeededTest, select_by_precision
4141
from pymc3.tests.test_distributions import (
4242
Domain,
43-
I,
4443
Nat,
4544
PdMatrix,
4645
PdMatrixChol,
@@ -314,12 +313,6 @@ class TestLogitNormal(BaseTestCases.BaseTestCase):
314313
params = {"mu": 0.0, "sigma": 1.0}
315314

316315

317-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
318-
class TestConstant(BaseTestCases.BaseTestCase):
319-
distribution = pm.Constant
320-
params = {"c": 3}
321-
322-
323316
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
324317
class TestZeroInflatedPoisson(BaseTestCases.BaseTestCase):
325318
distribution = pm.ZeroInflatedPoisson
@@ -918,6 +911,24 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng):
918911
]
919912

920913

914+
class TestConstant(BaseTestDistribution):
915+
def constant_rng_fn(self, size, c):
916+
if size is None:
917+
return c
918+
return np.full(size, c)
919+
920+
pymc_dist = pm.Constant
921+
pymc_dist_params = {"c": 3}
922+
expected_rv_op_params = {"c": 3}
923+
reference_dist_params = {"c": 3}
924+
reference_dist = lambda self: self.constant_rng_fn
925+
tests_to_run = [
926+
"check_pymc_params_match_rv_op",
927+
"check_pymc_draws_match_reference",
928+
"check_rv_size",
929+
]
930+
931+
921932
class TestScalarParameterSamples(SeededTest):
922933
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
923934
def test_bounded(self):
@@ -1027,13 +1038,6 @@ def test_half_flat(self):
10271038
with pytest.raises(ValueError):
10281039
f.random(1)
10291040

1030-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1031-
def test_constant_dist(self):
1032-
def ref_rand(size, c):
1033-
return c * np.ones(size, dtype=int)
1034-
1035-
pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
1036-
10371041
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
10381042
def test_matrix_normal(self):
10391043
def ref_rand(size, mu, rowcov, colcov):

0 commit comments

Comments
 (0)