Skip to content

Commit c5c48cc

Browse files
author
Joseph Hall
committed
Merge remote-tracking branch 'upstream/main' into feature/gp-cov-type-hints
2 parents 9bf270d + 146afc5 commit c5c48cc

File tree

11 files changed

+274
-42
lines changed

11 files changed

+274
-42
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ jobs:
413413
floatx: [float32]
414414
python-version: ["3.11"]
415415
test-subset:
416-
- tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py
416+
- tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
417417
fail-fast: false
418418
runs-on: ${{ matrix.os }}
419419
env:

pymc/distributions/continuous.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from pytensor.tensor.var import TensorConstant
5858

5959
from pymc.logprob.abstract import _logcdf_helper, _logprob_helper
60+
from pymc.logprob.basic import icdf
6061

6162
try:
6263
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -856,6 +857,11 @@ def logcdf(value, loc, sigma):
856857
msg="sigma > 0",
857858
)
858859

860+
def icdf(value, loc, sigma):
    """Inverse CDF evaluated at the probability ``value``.

    ``value`` is mapped to ``(value + 1) / 2`` and handed to the inverse CDF
    of ``Normal.dist(loc, sigma)``; the result is then passed through
    ``check_icdf_value`` together with the original ``value``.
    """
    # NOTE(review): the call below is meant to reach the ``icdf`` imported
    # from pymc.logprob.basic, which holds as long as this def sits in a
    # class body (class-level names are not visible inside function bodies).
    normal_prob = (value + 1.0) / 2.0
    quantile = icdf(Normal.dist(loc, sigma), normal_prob)
    return check_icdf_value(quantile, value)
864+
859865

860866
class WaldRV(RandomVariable):
861867
name = "wald"
@@ -1714,12 +1720,17 @@ def logcdf(value, mu, sigma):
17141720
-np.inf,
17151721
normal_lcdf(mu, sigma, pt.log(value)),
17161722
)
1723+
17171724
return check_parameters(
17181725
res,
17191726
sigma > 0,
17201727
msg="sigma > 0",
17211728
)
17221729

1730+
def icdf(value, mu, sigma):
    """Inverse CDF: ``exp`` of the ``Normal.dist(mu, sigma)`` inverse CDF at ``value``."""
    # NOTE(review): ``icdf`` here is expected to be the helper imported from
    # pymc.logprob.basic, assuming this def lives inside a class body.
    log_quantile = icdf(Normal.dist(mu, sigma), value)
    return pt.exp(log_quantile)
1733+
17231734

17241735
Lognormal = LogNormal
17251736

@@ -2121,6 +2132,15 @@ def logcdf(value, loc, beta):
21212132
msg="beta > 0",
21222133
)
21232134

2135+
def icdf(value, loc, beta):
    """Inverse CDF evaluated at the probability ``value``.

    Computes ``loc + beta * tan(pi * value / 2)``, passes it through
    ``check_icdf_value`` and finally through ``check_parameters`` with the
    condition ``beta > 0``.
    """
    angle = np.pi * (value) / 2.0
    quantile = check_icdf_value(loc + beta * pt.tan(angle), value)
    return check_parameters(quantile, beta > 0, msg="beta > 0")
2143+
21242144

21252145
class Gamma(PositiveContinuous):
21262146
r"""
@@ -2526,6 +2546,16 @@ def logp(value, alpha, beta):
25262546
msg="alpha > 0, beta > 0",
25272547
)
25282548

2549+
def icdf(value, alpha, beta):
    """Inverse CDF evaluated at the probability ``value``.

    Computes ``beta * (-log(1 - value)) ** (1 / alpha)``, passes it through
    ``check_icdf_value`` and finally through ``check_parameters`` with the
    conditions ``alpha > 0`` and ``beta > 0``.
    """
    # log1p(-value) instead of log(1 - value): forming ``1 - value`` first
    # cancels away the low-order digits of small probabilities (in float32 a
    # small enough ``value`` rounds to exactly 1 and the quantile becomes 0).
    neg_log_survival = -pt.log1p(-value)
    res = beta * neg_log_survival ** (1 / alpha)
    res = check_icdf_value(res, value)
    return check_parameters(
        res,
        alpha > 0,
        beta > 0,
        msg="alpha > 0, beta > 0",
    )
2558+
25292559

25302560
class HalfStudentTRV(RandomVariable):
25312561
name = "halfstudentt"
@@ -3069,6 +3099,20 @@ def logcdf(value, lower, c, upper):
30693099
msg="lower <= c <= upper",
30703100
)
30713101

3102+
def icdf(value, lower, c, upper):
    """Inverse CDF evaluated at the probability ``value``.

    Below the threshold ``(c - lower) / (upper - lower)`` the result is
    ``lower + sqrt((upper - lower) * (c - lower) * value)``; otherwise it is
    ``upper - sqrt((upper - lower) * (upper - c) * (1 - value))``. The result
    is passed through ``check_icdf_value`` and then ``check_parameters`` with
    the conditions ``lower <= c`` and ``c <= upper``.
    """
    width = upper - lower
    # pt.sqrt keeps the square root a PyTensor op like the surrounding
    # pt.switch / pt.lt, instead of relying on NumPy ufunc dispatch over
    # symbolic inputs.
    res = pt.switch(
        pt.lt(value, (c - lower) / width),
        lower + pt.sqrt(width * (c - lower) * value),
        upper - pt.sqrt(width * (upper - c) * (1 - value)),
    )
    res = check_icdf_value(res, value)
    return check_parameters(
        res,
        lower <= c,
        c <= upper,
        msg="lower <= c <= upper",
    )
3115+
30723116

30733117
@_default_transform.register(Triangular)
30743118
def triangular_default_transform(op, rv):
@@ -3157,6 +3201,15 @@ def logcdf(value, mu, beta):
31573201
msg="beta > 0",
31583202
)
31593203

3204+
def icdf(value, mu, beta):
    """Inverse CDF evaluated at the probability ``value``.

    Computes ``mu - beta * log(-log(value))``, passes it through
    ``check_icdf_value`` and finally through ``check_parameters`` with the
    condition ``beta > 0``.
    """
    log_neg_log = pt.log(-pt.log(value))
    quantile = check_icdf_value(mu - beta * log_neg_log, value)
    return check_parameters(quantile, beta > 0, msg="beta > 0")
3212+
31603213

31613214
class RiceRV(RandomVariable):
31623215
name = "rice"
@@ -3713,6 +3766,15 @@ def logcdf(value, mu, sigma):
37133766
msg="sigma > 0",
37143767
)
37153768

3769+
def icdf(value, mu, sigma):
    """Inverse CDF evaluated at the probability ``value``.

    Computes ``sigma * -log(2 * erfcinv(value) ** 2) + mu``, passes it
    through ``check_icdf_value`` and finally through ``check_parameters``
    with the condition ``sigma > 0``.
    """
    neg_log_term = -pt.log(2.0 * pt.erfcinv(value) ** 2)
    quantile = check_icdf_value(sigma * neg_log_term + mu, value)
    return check_parameters(quantile, sigma > 0, msg="sigma > 0")
3777+
37163778

37173779
class PolyaGammaRV(RandomVariable):
37183780
"""Polya-Gamma random variable."""

pymc/distributions/multivariate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
132132
chol = pt.as_tensor_variable(chol)
133133
if chol.ndim != 2:
134134
raise ValueError("chol must be two dimensional.")
135+
136+
# tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
137+
chol.tag.lower_triangular = True
135138
cov = chol.dot(chol.T)
136139

137140
return cov

pymc/logprob/transforms.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
448448

449449
backward_value = op.transform_elemwise.backward(value, *other_inputs)
450450

451-
# Some transformations, like squaring may produce multiple backward values
451+
# Fail if transformation is not injective
452+
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
452453
if isinstance(backward_value, tuple):
453454
raise NotImplementedError
454455

@@ -469,6 +470,11 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
469470
input_icdf = _icdf_helper(measurable_input, value)
470471
icdf = op.transform_elemwise.forward(input_icdf, *other_inputs)
471472

473+
# Fail if transformation is not injective
474+
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
475+
if isinstance(op.transform_elemwise.backward(icdf, *other_inputs), tuple):
476+
raise NotImplementedError
477+
472478
return icdf
473479

474480

@@ -958,8 +964,10 @@ class SimplexTransform(RVTransform):
958964
name = "simplex"
959965

960966
def forward(self, value, *inputs):
967+
value = pt.as_tensor(value)
961968
log_value = pt.log(value)
962-
shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
969+
N = value.shape[-1].astype(value.dtype)
970+
shift = pt.sum(log_value, -1, keepdims=True) / N
963971
return log_value[..., :-1] - shift
964972

965973
def backward(self, value, *inputs):
@@ -968,7 +976,9 @@ def backward(self, value, *inputs):
968976
return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True)
969977

970978
def log_jac_det(self, value, *inputs):
979+
value = pt.as_tensor(value)
971980
N = value.shape[-1] + 1
981+
N = N.astype(value.dtype)
972982
sum_value = pt.sum(value, -1, keepdims=True)
973983
value_sum_expanded = value + sum_value
974984
value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)

pymc/math.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,11 +443,17 @@ def expand_packed_triangular(n, packed, lower=True, diagonal_only=False):
443443
elif lower:
444444
out = pt.zeros((n, n), dtype=pytensor.config.floatX)
445445
idxs = np.tril_indices(n)
446-
return pt.set_subtensor(out[idxs], packed)
446+
# tag as lower triangular to enable pytensor rewrites
447+
out = pt.set_subtensor(out[idxs], packed)
448+
out.tag.lower_triangular = True
449+
return out
447450
elif not lower:
448451
out = pt.zeros((n, n), dtype=pytensor.config.floatX)
449452
idxs = np.triu_indices(n)
450-
return pt.set_subtensor(out[idxs], packed)
453+
# tag as upper triangular to enable pytensor rewrites
454+
out = pt.set_subtensor(out[idxs], packed)
455+
out.tag.upper_triangular = True
456+
return out
451457

452458

453459
class BatchedDiag(Op):

pymc/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
hessian,
7676
inputvars,
7777
replace_rvs_by_values,
78+
rewrite_pregrad,
7879
)
7980
from pymc.util import (
8081
UNSET,
@@ -381,6 +382,8 @@ def __init__(
381382
self._extra_vars_shared[var.name] = shared
382383
givens.append((var, shared))
383384

385+
cost = rewrite_pregrad(cost)
386+
384387
if compute_grads:
385388
grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
386389
for grad_wrt, var in zip(grads, grad_vars):
@@ -824,6 +827,7 @@ def dlogp(
824827
)
825828

826829
cost = self.logp(jacobian=jacobian)
830+
cost = rewrite_pregrad(cost)
827831
return gradient(cost, value_vars)
828832

829833
def d2logp(
@@ -862,6 +866,7 @@ def d2logp(
862866
)
863867

864868
cost = self.logp(jacobian=jacobian)
869+
cost = rewrite_pregrad(cost)
865870
return hessian(cost, value_vars)
866871

867872
@property

pymc/pytensorf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,3 +1228,10 @@ def constant_fold(
12281228
return tuple(
12291229
folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
12301230
)
1231+
1232+
1233+
def rewrite_pregrad(graph):
    """Apply simplifying or stabilizing rewrites that are safe to use pre-grad.

    Calls ``rewrite_graph`` on ``graph`` with the "canonicalize" and
    "stabilize" rewrites included and returns whatever it returns.
    """
    pregrad_rewrites = ("canonicalize", "stabilize")
    return rewrite_graph(graph, include=pregrad_rewrites)

tests/distributions/test_continuous.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,12 @@ def test_triangular(self):
207207
lambda value, c, lower, upper: st.triang.logcdf(value, c - lower, lower, upper - lower),
208208
skip_paramdomain_outside_edge_test=True,
209209
)
210+
check_icdf(
211+
pm.Triangular,
212+
{"lower": -Rplusunif, "c": Runif, "upper": Rplusunif},
213+
lambda q, c, lower, upper: st.triang.ppf(q, c - lower, lower, upper - lower),
214+
skip_paramdomain_outside_edge_test=True,
215+
)
210216

211217
# Custom logp/logcdf check for values outside of domain
212218
valid_dist = pm.Triangular.dist(lower=0, upper=1, c=0.9, size=2)
@@ -299,6 +305,11 @@ def test_half_normal(self):
299305
{"sigma": Rplus},
300306
lambda value, sigma: st.halfnorm.logcdf(value, scale=sigma),
301307
)
308+
check_icdf(
309+
pm.HalfNormal,
310+
{"sigma": Rplus},
311+
lambda q, sigma: st.halfnorm.ppf(q, scale=sigma),
312+
)
302313

303314
def test_chisquared_logp(self):
304315
check_logp(
@@ -502,6 +513,21 @@ def test_lognormal(self):
502513
{"mu": R, "sigma": Rplusbig},
503514
lambda value, mu, sigma: st.lognorm.logcdf(value, sigma, 0, np.exp(mu)),
504515
)
516+
check_icdf(
517+
pm.LogNormal,
518+
{"mu": R, "tau": Rplusbig},
519+
lambda q, mu, tau: floatX(st.lognorm.ppf(q, tau**-0.5, 0, np.exp(mu))),
520+
)
521+
# Because we exponentiate the normal quantile function, setting sigma >= 9.5
522+
# returns extreme values that result in relative errors above 4 digits
523+
# we circumvent it by keeping it below or equal to 9.
524+
custom_rplusbig = Domain([0, 0.5, 0.9, 0.99, 1, 1.5, 2, 9, np.inf])
525+
check_icdf(
526+
pm.LogNormal,
527+
{"mu": R, "sigma": custom_rplusbig},
528+
lambda q, mu, sigma: floatX(st.lognorm.ppf(q, sigma, 0, np.exp(mu))),
529+
decimal=select_by_precision(float64=4, float32=3),
530+
)
505531

506532
def test_studentt_logp(self):
507533
check_logp(
@@ -567,6 +593,9 @@ def test_half_cauchy(self):
567593
{"beta": Rplusbig},
568594
lambda value, beta: st.halfcauchy.logcdf(value, scale=beta),
569595
)
596+
check_icdf(
597+
pm.HalfCauchy, {"beta": Rplusbig}, lambda q, beta: st.halfcauchy.ppf(q, scale=beta)
598+
)
570599

571600
def test_gamma_logp(self):
572601
check_logp(
@@ -681,6 +710,13 @@ def test_weibull_logcdf(self):
681710
lambda value, alpha, beta: st.exponweib.logcdf(value, 1, alpha, scale=beta),
682711
)
683712

713+
def test_weibull_icdf(self):
    """Compare ``pm.Weibull``'s icdf with ``st.exponweib.ppf`` (first shape fixed to 1)."""

    def scipy_ppf(q, alpha, beta):
        # Reference quantile function; parameter names mirror the domains below.
        return st.exponweib.ppf(q, 1, alpha, scale=beta)

    paramdomains = {"alpha": Rplusbig, "beta": Rplusbig}
    check_icdf(pm.Weibull, paramdomains, scipy_ppf)
719+
684720
def test_half_studentt(self):
685721
# this is only testing for nu=1 (halfcauchy)
686722
check_logp(
@@ -757,6 +793,11 @@ def test_gumbel(self):
757793
{"mu": R, "beta": Rplusbig},
758794
lambda value, mu, beta: st.gumbel_r.logcdf(value, loc=mu, scale=beta),
759795
)
796+
check_icdf(
797+
pm.Gumbel,
798+
{"mu": R, "beta": Rplusbig},
799+
lambda q, mu, beta: st.gumbel_r.ppf(q, loc=mu, scale=beta),
800+
)
760801

761802
def test_logistic(self):
762803
check_logp(
@@ -840,6 +881,13 @@ def test_moyal_logcdf(self):
840881
if pytensor.config.floatX == "float32":
841882
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
842883

884+
def test_moyal_icdf(self):
    """Compare ``pm.Moyal``'s icdf with ``st.moyal.ppf`` cast through ``floatX``."""

    def scipy_ppf(q, mu, sigma):
        # Reference quantile function; parameter names mirror the domains below.
        return floatX(st.moyal.ppf(q, mu, sigma))

    paramdomains = {"mu": R, "sigma": Rplusbig}
    check_icdf(pm.Moyal, paramdomains, scipy_ppf)
890+
843891
def test_interpolated(self):
844892
for mu in R.vals:
845893
for sigma in Rplus.vals:

0 commit comments

Comments
 (0)