Skip to content

Commit 58dfb35

Browse files
Fix ordering Transformation for batched dimensions (#6255)
Co-authored-by: Purna Chandra Mansingh <[email protected]>
1 parent eb16ce6 commit 58dfb35

File tree

2 files changed

+197
-24
lines changed

2 files changed

+197
-24
lines changed

pymc/distributions/transforms.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@
3838
"Interval",
3939
"log_exp_m1",
4040
"ordered",
41+
"univariate_ordered",
42+
"multivariate_ordered",
4143
"log",
4244
"sum_to_1",
45+
"univariate_sum_to_1",
46+
"multivariate_sum_to_1",
4347
"circular",
4448
"CholeskyCovPacked",
4549
"Chain",
@@ -74,6 +78,14 @@ def log_jac_det(self, value, *inputs):
7478
class Ordered(RVTransform):
7579
name = "ordered"
7680

81+
def __init__(self, ndim_supp=0):
82+
if ndim_supp > 1:
83+
raise ValueError(
84+
f"For Ordered transformation number of core dimensions"
85+
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
86+
)
87+
self.ndim_supp = ndim_supp
88+
7789
def backward(self, value, *inputs):
7890
x = at.zeros(value.shape)
7991
x = at.inc_subtensor(x[..., 0], value[..., 0])
@@ -87,7 +99,10 @@ def forward(self, value, *inputs):
8799
return y
88100

89101
def log_jac_det(self, value, *inputs):
90-
return at.sum(value[..., 1:], axis=-1)
102+
if self.ndim_supp == 0:
103+
return at.sum(value[..., 1:], axis=-1, keepdims=True)
104+
else:
105+
return at.sum(value[..., 1:], axis=-1)
91106

92107

93108
class SumTo1(RVTransform):
@@ -98,6 +113,14 @@ class SumTo1(RVTransform):
98113

99114
name = "sumto1"
100115

116+
def __init__(self, ndim_supp=0):
117+
if ndim_supp > 1:
118+
raise ValueError(
119+
f"For SumTo1 transformation number of core dimensions"
120+
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
121+
)
122+
self.ndim_supp = ndim_supp
123+
101124
def backward(self, value, *inputs):
102125
remaining = 1 - at.sum(value[..., :], axis=-1, keepdims=True)
103126
return at.concatenate([value[..., :], remaining], axis=-1)
@@ -107,7 +130,10 @@ def forward(self, value, *inputs):
107130

108131
def log_jac_det(self, value, *inputs):
109132
y = at.zeros(value.shape)
110-
return at.sum(y, axis=-1)
133+
if self.ndim_supp == 0:
134+
return at.sum(y, axis=-1, keepdims=True)
135+
else:
136+
return at.sum(y, axis=-1)
111137

112138

113139
class CholeskyCovPacked(RVTransform):
@@ -330,20 +356,46 @@ def extend_axis_rev(array, axis):
330356
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
331357
for use in the ``transform`` argument of a random variable."""
332358

333-
ordered = Ordered()
359+
univariate_ordered = Ordered(ndim_supp=0)
360+
univariate_ordered.__doc__ = """
361+
Instantiation of :class:`pymc.distributions.transforms.Ordered`
362+
for use in the ``transform`` argument of a univariate random variable."""
363+
364+
multivariate_ordered = Ordered(ndim_supp=1)
365+
multivariate_ordered.__doc__ = """
366+
Instantiation of :class:`pymc.distributions.transforms.Ordered`
367+
for use in the ``transform`` argument of a multivariate random variable."""
368+
369+
# backwards compatibility
370+
ordered = Ordered(ndim_supp=1)
334371
ordered.__doc__ = """
335372
Instantiation of :class:`pymc.distributions.transforms.Ordered`
336-
for use in the ``transform`` argument of a random variable."""
373+
for use in the ``transform`` argument of a random variable.
374+
This instantiation is for backwards compatibility only.
375+
Please use `univariate_ordererd` or `multivariate_ordered` instead."""
337376

338377
log = LogTransform()
339378
log.__doc__ = """
340379
Instantiation of :class:`aeppl.transforms.LogTransform`
341380
for use in the ``transform`` argument of a random variable."""
342381

343-
sum_to_1 = SumTo1()
382+
univariate_sum_to_1 = SumTo1(ndim_supp=0)
383+
univariate_sum_to_1.__doc__ = """
384+
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
385+
for use in the ``transform`` argument of a univariate random variable."""
386+
387+
multivariate_sum_to_1 = SumTo1(ndim_supp=1)
388+
multivariate_sum_to_1.__doc__ = """
389+
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
390+
for use in the ``transform`` argument of a multivariate random variable."""
391+
392+
# backwards compatibility
393+
sum_to_1 = SumTo1(ndim_supp=1)
344394
sum_to_1.__doc__ = """
345395
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
346-
for use in the ``transform`` argument of a random variable."""
396+
for use in the ``transform`` argument of a random variable.
397+
This instantiation is for backwards compatibility only.
398+
Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead."""
347399

348400
circular = CircularTransform()
349401
circular.__doc__ = """

pymc/tests/distributions/test_transform.py

Lines changed: 139 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
from typing import Union
17+
1618
import aesara
1719
import aesara.tensor as at
1820
import numpy as np
@@ -139,10 +141,18 @@ def test_simplex_accuracy():
139141

140142

141143
def test_sum_to_1():
142-
check_vector_transform(tr.sum_to_1, Simplex(2))
143-
check_vector_transform(tr.sum_to_1, Simplex(4))
144+
check_vector_transform(tr.univariate_sum_to_1, Simplex(2))
145+
check_vector_transform(tr.univariate_sum_to_1, Simplex(4))
144146

145-
check_jacobian_det(tr.sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1])
147+
with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"):
148+
tr.SumTo1(2)
149+
150+
check_jacobian_det(
151+
tr.univariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
152+
)
153+
check_jacobian_det(
154+
tr.multivariate_sum_to_1, Vector(Unit, 2), at.dvector, np.array([0, 0]), lambda x: x[:-1]
155+
)
146156

147157

148158
def test_log():
@@ -241,28 +251,36 @@ def test_circular():
241251

242252

243253
def test_ordered():
244-
check_vector_transform(tr.ordered, SortedVector(6))
254+
check_vector_transform(tr.univariate_ordered, SortedVector(6))
255+
256+
with pytest.raises(ValueError, match=r"\(ndim_supp\) must not exceed 1"):
257+
tr.Ordered(2)
245258

246-
check_jacobian_det(tr.ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False)
259+
check_jacobian_det(
260+
tr.univariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
261+
)
262+
check_jacobian_det(
263+
tr.multivariate_ordered, Vector(R, 2), at.dvector, np.array([0, 0]), elemwise=False
264+
)
247265

248-
vals = get_values(tr.ordered, Vector(R, 3), at.dvector, np.zeros(3))
266+
vals = get_values(tr.univariate_ordered, Vector(R, 3), at.dvector, np.zeros(3))
249267
close_to_logical(np.diff(vals) >= 0, True, tol)
250268

251269

252270
def test_chain_values():
253-
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
271+
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
254272
vals = get_values(chain_tranf, Vector(R, 5), at.dvector, np.zeros(5))
255273
close_to_logical(np.diff(vals) >= 0, True, tol)
256274

257275

258276
def test_chain_vector_transform():
259-
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
277+
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
260278
check_vector_transform(chain_tranf, UnitSortedVector(3))
261279

262280

263281
@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
264282
def test_chain_jacob_det():
265-
chain_tranf = tr.Chain([tr.logodds, tr.ordered])
283+
chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
266284
check_jacobian_det(chain_tranf, Vector(R, 4), at.dvector, np.zeros(4), elemwise=False)
267285

268286

@@ -327,7 +345,14 @@ def check_vectortransform_elementwise_logp(self, model):
327345
jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs)
328346
# Original distribution is univariate
329347
if x.owner.op.ndim_supp == 0:
330-
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
348+
tr_steps = getattr(transform, "transform_list", [transform])
349+
transform_keeps_dim = any(
350+
[isinstance(ts, Union[tr.SumTo1, tr.Ordered]) for ts in tr_steps]
351+
)
352+
if transform_keeps_dim:
353+
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
354+
else:
355+
assert model.logp(x, sum=False)[0].ndim == x.ndim == (jacob_det.ndim + 1)
331356
# Original distribution is multivariate
332357
else:
333358
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
@@ -449,7 +474,7 @@ def test_normal_ordered(self):
449474
{"mu": 0.0, "sigma": 1.0},
450475
size=3,
451476
initval=np.asarray([-1.0, 1.0, 4.0]),
452-
transform=tr.ordered,
477+
transform=tr.univariate_ordered,
453478
)
454479
self.check_vectortransform_elementwise_logp(model)
455480

@@ -467,7 +492,7 @@ def test_half_normal_ordered(self, sigma, size):
467492
{"sigma": sigma},
468493
size=size,
469494
initval=initval,
470-
transform=tr.Chain([tr.log, tr.ordered]),
495+
transform=tr.Chain([tr.log, tr.univariate_ordered]),
471496
)
472497
self.check_vectortransform_elementwise_logp(model)
473498

@@ -479,7 +504,7 @@ def test_exponential_ordered(self, lam, size):
479504
{"lam": lam},
480505
size=size,
481506
initval=initval,
482-
transform=tr.Chain([tr.log, tr.ordered]),
507+
transform=tr.Chain([tr.log, tr.univariate_ordered]),
483508
)
484509
self.check_vectortransform_elementwise_logp(model)
485510

@@ -501,7 +526,7 @@ def test_beta_ordered(self, a, b, size):
501526
{"alpha": a, "beta": b},
502527
size=size,
503528
initval=initval,
504-
transform=tr.Chain([tr.logodds, tr.ordered]),
529+
transform=tr.Chain([tr.logodds, tr.univariate_ordered]),
505530
)
506531
self.check_vectortransform_elementwise_logp(model)
507532

@@ -524,7 +549,7 @@ def transform_params(*inputs):
524549
{"lower": lower, "upper": upper},
525550
size=size,
526551
initval=initval,
527-
transform=tr.Chain([interval, tr.ordered]),
552+
transform=tr.Chain([interval, tr.univariate_ordered]),
528553
)
529554
self.check_vectortransform_elementwise_logp(model)
530555

@@ -536,7 +561,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
536561
{"mu": mu, "kappa": kappa},
537562
size=size,
538563
initval=initval,
539-
transform=tr.Chain([tr.circular, tr.ordered]),
564+
transform=tr.Chain([tr.circular, tr.univariate_ordered]),
540565
)
541566
self.check_vectortransform_elementwise_logp(model)
542567

@@ -545,7 +570,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
545570
[
546571
(0.0, 1.0, (2,), tr.simplex),
547572
(0.5, 5.5, (2, 3), tr.simplex),
548-
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.sum_to_1, tr.logodds])),
573+
(np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
549574
],
550575
)
551576
def test_uniform_other(self, lower, upper, size, transform):
@@ -569,7 +594,11 @@ def test_uniform_other(self, lower, upper, size, transform):
569594
def test_mvnormal_ordered(self, mu, cov, size, shape):
570595
initval = np.sort(np.random.randn(*shape))
571596
model = self.build_model(
572-
pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered
597+
pm.MvNormal,
598+
{"mu": mu, "cov": cov},
599+
size=size,
600+
initval=initval,
601+
transform=tr.multivariate_ordered,
573602
)
574603
self.check_vectortransform_elementwise_logp(model)
575604

@@ -598,3 +627,95 @@ def test_discrete_trafo():
598627
with pytest.raises(ValueError) as err:
599628
pm.Binomial("a", n=5, p=0.5, transform="log")
600629
err.match("Transformations for discrete distributions")
630+
631+
632+
def test_2d_univariate_ordered():
633+
with pm.Model() as model:
634+
x_1d = pm.Normal(
635+
"x_1d",
636+
mu=[-3, -1, 1, 2],
637+
sigma=1,
638+
size=(4,),
639+
transform=tr.univariate_ordered,
640+
)
641+
x_2d = pm.Normal(
642+
"x_2d",
643+
mu=[-3, -1, 1, 2],
644+
sigma=1,
645+
size=(10, 4),
646+
transform=tr.univariate_ordered,
647+
)
648+
649+
log_p = model.compile_logp(sum=False)(
650+
{"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))}
651+
)
652+
np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])
653+
654+
655+
def test_2d_multivariate_ordered():
656+
with pm.Model() as model:
657+
x_1d = pm.MvNormal(
658+
"x_1d",
659+
mu=[-1, 1],
660+
cov=np.eye(2),
661+
initval=[-1, 1],
662+
transform=tr.multivariate_ordered,
663+
)
664+
x_2d = pm.MvNormal(
665+
"x_2d",
666+
mu=[-1, 1],
667+
cov=np.eye(2),
668+
size=2,
669+
initval=[[-1, 1], [-1, 1]],
670+
transform=tr.multivariate_ordered,
671+
)
672+
673+
log_p = model.compile_logp(sum=False)(
674+
{"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))}
675+
)
676+
np.testing.assert_allclose(log_p[0], log_p[1])
677+
678+
679+
def test_2d_univariate_sum_to_1():
680+
with pm.Model() as model:
681+
x_1d = pm.Normal(
682+
"x_1d",
683+
mu=[-3, -1, 1, 2],
684+
sigma=1,
685+
size=(4,),
686+
transform=tr.univariate_sum_to_1,
687+
)
688+
x_2d = pm.Normal(
689+
"x_2d",
690+
mu=[-3, -1, 1, 2],
691+
sigma=1,
692+
size=(10, 4),
693+
transform=tr.univariate_sum_to_1,
694+
)
695+
696+
log_p = model.compile_logp(sum=False)(
697+
{"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))}
698+
)
699+
np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])
700+
701+
702+
def test_2d_multivariate_sum_to_1():
703+
with pm.Model() as model:
704+
x_1d = pm.MvNormal(
705+
"x_1d",
706+
mu=[-1, 1],
707+
cov=np.eye(2),
708+
transform=tr.multivariate_sum_to_1,
709+
)
710+
x_2d = pm.MvNormal(
711+
"x_2d",
712+
mu=[-1, 1],
713+
cov=np.eye(2),
714+
size=2,
715+
transform=tr.multivariate_sum_to_1,
716+
)
717+
718+
log_p = model.compile_logp(sum=False)(
719+
{"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))}
720+
)
721+
np.testing.assert_allclose(log_p[0], log_p[1])

0 commit comments

Comments
 (0)