Skip to content

Commit 4a88c17

Browse files
ricardoV94brandonwillard
authored andcommitted
Fix HalfNormal/HalfNormalRV parameterization
1 parent 7f2e3e7 commit 4a88c17

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

pymc3/distributions/continuous.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -786,18 +786,12 @@ def dist(cls, sigma=None, tau=None, sd=None, *args, **kwargs):
786786

787787
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
788788

789-
# sigma = sd = sigma = aet.as_tensor_variable(sigma)
790-
# tau = tau = aet.as_tensor_variable(tau)
791-
792-
# mean = aet.sqrt(2 / (np.pi * tau))
793-
# variance = (1.0 - 2 / np.pi) / tau
794-
795789
assert_negative_support(tau, "tau", "HalfNormal")
796790
assert_negative_support(sigma, "sigma", "HalfNormal")
797791

798-
return super().dist([sigma, tau], **kwargs)
792+
return super().dist([0.0, sigma], **kwargs)
799793

800-
def logp(value, sigma, tau):
794+
def logp(value, loc, sigma):
801795
"""
802796
Calculate log-probability of HalfNormal distribution at specified value.
803797
@@ -811,14 +805,16 @@ def logp(value, sigma, tau):
811805
-------
812806
TensorVariable
813807
"""
808+
tau, sigma = get_tau_sigma(tau=None, sigma=sigma)
809+
814810
return bound(
815-
-0.5 * tau * value ** 2 + 0.5 * aet.log(tau * 2.0 / np.pi),
816-
value >= 0,
811+
-0.5 * tau * (value - loc) ** 2 + 0.5 * aet.log(tau * 2.0 / np.pi),
812+
value >= loc,
817813
tau > 0,
818814
sigma > 0,
819815
)
820816

821-
def logcdf(value, sigma, tau):
817+
def logcdf(value, loc, sigma):
822818
"""
823819
Compute the log of the cumulative distribution function for HalfNormal distribution
824820
at the specified value.
@@ -833,10 +829,10 @@ def logcdf(value, sigma, tau):
833829
-------
834830
TensorVariable
835831
"""
836-
z = zvalue(value, mu=0, sigma=sigma)
832+
z = zvalue(value, mu=loc, sigma=sigma)
837833
return bound(
838834
aet.log1p(-aet.erfc(z / aet.sqrt(2.0))),
839-
0 <= value,
835+
loc <= value,
840836
0 < sigma,
841837
)
842838

0 commit comments

Comments
 (0)