13
13
# limitations under the License.
14
14
15
15
16
+ from typing import Union
17
+
16
18
import aesara
17
19
import aesara .tensor as at
18
20
import numpy as np
@@ -139,10 +141,18 @@ def test_simplex_accuracy():
139
141
140
142
141
143
def test_sum_to_1 ():
142
- check_vector_transform (tr .sum_to_1 , Simplex (2 ))
143
- check_vector_transform (tr .sum_to_1 , Simplex (4 ))
144
+ check_vector_transform (tr .univariate_sum_to_1 , Simplex (2 ))
145
+ check_vector_transform (tr .univariate_sum_to_1 , Simplex (4 ))
144
146
145
- check_jacobian_det (tr .sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ])
147
+ with pytest .raises (ValueError , match = r"\(ndim_supp\) must not exceed 1" ):
148
+ tr .SumTo1 (2 )
149
+
150
+ check_jacobian_det (
151
+ tr .univariate_sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ]
152
+ )
153
+ check_jacobian_det (
154
+ tr .multivariate_sum_to_1 , Vector (Unit , 2 ), at .dvector , np .array ([0 , 0 ]), lambda x : x [:- 1 ]
155
+ )
146
156
147
157
148
158
def test_log ():
@@ -241,28 +251,36 @@ def test_circular():
241
251
242
252
243
253
def test_ordered ():
244
- check_vector_transform (tr .ordered , SortedVector (6 ))
254
+ check_vector_transform (tr .univariate_ordered , SortedVector (6 ))
255
+
256
+ with pytest .raises (ValueError , match = r"\(ndim_supp\) must not exceed 1" ):
257
+ tr .Ordered (2 )
245
258
246
- check_jacobian_det (tr .ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False )
259
+ check_jacobian_det (
260
+ tr .univariate_ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False
261
+ )
262
+ check_jacobian_det (
263
+ tr .multivariate_ordered , Vector (R , 2 ), at .dvector , np .array ([0 , 0 ]), elemwise = False
264
+ )
247
265
248
- vals = get_values (tr .ordered , Vector (R , 3 ), at .dvector , np .zeros (3 ))
266
+ vals = get_values (tr .univariate_ordered , Vector (R , 3 ), at .dvector , np .zeros (3 ))
249
267
close_to_logical (np .diff (vals ) >= 0 , True , tol )
250
268
251
269
252
270
def test_chain_values ():
253
- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
271
+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
254
272
vals = get_values (chain_tranf , Vector (R , 5 ), at .dvector , np .zeros (5 ))
255
273
close_to_logical (np .diff (vals ) >= 0 , True , tol )
256
274
257
275
258
276
def test_chain_vector_transform ():
259
- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
277
+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
260
278
check_vector_transform (chain_tranf , UnitSortedVector (3 ))
261
279
262
280
263
281
@pytest .mark .xfail (reason = "Fails due to precision issue. Values just close to expected." )
264
282
def test_chain_jacob_det ():
265
- chain_tranf = tr .Chain ([tr .logodds , tr .ordered ])
283
+ chain_tranf = tr .Chain ([tr .logodds , tr .univariate_ordered ])
266
284
check_jacobian_det (chain_tranf , Vector (R , 4 ), at .dvector , np .zeros (4 ), elemwise = False )
267
285
268
286
@@ -327,7 +345,14 @@ def check_vectortransform_elementwise_logp(self, model):
327
345
jacob_det = transform .log_jac_det (test_array_transf , * x .owner .inputs )
328
346
# Original distribution is univariate
329
347
if x .owner .op .ndim_supp == 0 :
330
- assert model .logp (x , sum = False )[0 ].ndim == x .ndim == (jacob_det .ndim + 1 )
348
+ tr_steps = getattr (transform , "transform_list" , [transform ])
349
+ transform_keeps_dim = any (
350
+ [isinstance (ts , Union [tr .SumTo1 , tr .Ordered ]) for ts in tr_steps ]
351
+ )
352
+ if transform_keeps_dim :
353
+ assert model .logp (x , sum = False )[0 ].ndim == x .ndim == jacob_det .ndim
354
+ else :
355
+ assert model .logp (x , sum = False )[0 ].ndim == x .ndim == (jacob_det .ndim + 1 )
331
356
# Original distribution is multivariate
332
357
else :
333
358
assert model .logp (x , sum = False )[0 ].ndim == (x .ndim - 1 ) == jacob_det .ndim
@@ -449,7 +474,7 @@ def test_normal_ordered(self):
449
474
{"mu" : 0.0 , "sigma" : 1.0 },
450
475
size = 3 ,
451
476
initval = np .asarray ([- 1.0 , 1.0 , 4.0 ]),
452
- transform = tr .ordered ,
477
+ transform = tr .univariate_ordered ,
453
478
)
454
479
self .check_vectortransform_elementwise_logp (model )
455
480
@@ -467,7 +492,7 @@ def test_half_normal_ordered(self, sigma, size):
467
492
{"sigma" : sigma },
468
493
size = size ,
469
494
initval = initval ,
470
- transform = tr .Chain ([tr .log , tr .ordered ]),
495
+ transform = tr .Chain ([tr .log , tr .univariate_ordered ]),
471
496
)
472
497
self .check_vectortransform_elementwise_logp (model )
473
498
@@ -479,7 +504,7 @@ def test_exponential_ordered(self, lam, size):
479
504
{"lam" : lam },
480
505
size = size ,
481
506
initval = initval ,
482
- transform = tr .Chain ([tr .log , tr .ordered ]),
507
+ transform = tr .Chain ([tr .log , tr .univariate_ordered ]),
483
508
)
484
509
self .check_vectortransform_elementwise_logp (model )
485
510
@@ -501,7 +526,7 @@ def test_beta_ordered(self, a, b, size):
501
526
{"alpha" : a , "beta" : b },
502
527
size = size ,
503
528
initval = initval ,
504
- transform = tr .Chain ([tr .logodds , tr .ordered ]),
529
+ transform = tr .Chain ([tr .logodds , tr .univariate_ordered ]),
505
530
)
506
531
self .check_vectortransform_elementwise_logp (model )
507
532
@@ -524,7 +549,7 @@ def transform_params(*inputs):
524
549
{"lower" : lower , "upper" : upper },
525
550
size = size ,
526
551
initval = initval ,
527
- transform = tr .Chain ([interval , tr .ordered ]),
552
+ transform = tr .Chain ([interval , tr .univariate_ordered ]),
528
553
)
529
554
self .check_vectortransform_elementwise_logp (model )
530
555
@@ -536,7 +561,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
536
561
{"mu" : mu , "kappa" : kappa },
537
562
size = size ,
538
563
initval = initval ,
539
- transform = tr .Chain ([tr .circular , tr .ordered ]),
564
+ transform = tr .Chain ([tr .circular , tr .univariate_ordered ]),
540
565
)
541
566
self .check_vectortransform_elementwise_logp (model )
542
567
@@ -545,7 +570,7 @@ def test_vonmises_ordered(self, mu, kappa, size):
545
570
[
546
571
(0.0 , 1.0 , (2 ,), tr .simplex ),
547
572
(0.5 , 5.5 , (2 , 3 ), tr .simplex ),
548
- (np .zeros (3 ), np .ones (3 ), (4 , 3 ), tr .Chain ([tr .sum_to_1 , tr .logodds ])),
573
+ (np .zeros (3 ), np .ones (3 ), (4 , 3 ), tr .Chain ([tr .univariate_sum_to_1 , tr .logodds ])),
549
574
],
550
575
)
551
576
def test_uniform_other (self , lower , upper , size , transform ):
@@ -569,7 +594,11 @@ def test_uniform_other(self, lower, upper, size, transform):
569
594
def test_mvnormal_ordered (self , mu , cov , size , shape ):
570
595
initval = np .sort (np .random .randn (* shape ))
571
596
model = self .build_model (
572
- pm .MvNormal , {"mu" : mu , "cov" : cov }, size = size , initval = initval , transform = tr .ordered
597
+ pm .MvNormal ,
598
+ {"mu" : mu , "cov" : cov },
599
+ size = size ,
600
+ initval = initval ,
601
+ transform = tr .multivariate_ordered ,
573
602
)
574
603
self .check_vectortransform_elementwise_logp (model )
575
604
@@ -598,3 +627,95 @@ def test_discrete_trafo():
598
627
with pytest .raises (ValueError ) as err :
599
628
pm .Binomial ("a" , n = 5 , p = 0.5 , transform = "log" )
600
629
err .match ("Transformations for discrete distributions" )
630
+
631
+
632
+ def test_2d_univariate_ordered ():
633
+ with pm .Model () as model :
634
+ x_1d = pm .Normal (
635
+ "x_1d" ,
636
+ mu = [- 3 , - 1 , 1 , 2 ],
637
+ sigma = 1 ,
638
+ size = (4 ,),
639
+ transform = tr .univariate_ordered ,
640
+ )
641
+ x_2d = pm .Normal (
642
+ "x_2d" ,
643
+ mu = [- 3 , - 1 , 1 , 2 ],
644
+ sigma = 1 ,
645
+ size = (10 , 4 ),
646
+ transform = tr .univariate_ordered ,
647
+ )
648
+
649
+ log_p = model .compile_logp (sum = False )(
650
+ {"x_1d_ordered__" : np .zeros ((4 ,)), "x_2d_ordered__" : np .zeros ((10 , 4 ))}
651
+ )
652
+ np .testing .assert_allclose (np .tile (log_p [0 ], (10 , 1 )), log_p [1 ])
653
+
654
+
655
+ def test_2d_multivariate_ordered ():
656
+ with pm .Model () as model :
657
+ x_1d = pm .MvNormal (
658
+ "x_1d" ,
659
+ mu = [- 1 , 1 ],
660
+ cov = np .eye (2 ),
661
+ initval = [- 1 , 1 ],
662
+ transform = tr .multivariate_ordered ,
663
+ )
664
+ x_2d = pm .MvNormal (
665
+ "x_2d" ,
666
+ mu = [- 1 , 1 ],
667
+ cov = np .eye (2 ),
668
+ size = 2 ,
669
+ initval = [[- 1 , 1 ], [- 1 , 1 ]],
670
+ transform = tr .multivariate_ordered ,
671
+ )
672
+
673
+ log_p = model .compile_logp (sum = False )(
674
+ {"x_1d_ordered__" : np .zeros ((2 ,)), "x_2d_ordered__" : np .zeros ((2 , 2 ))}
675
+ )
676
+ np .testing .assert_allclose (log_p [0 ], log_p [1 ])
677
+
678
+
679
+ def test_2d_univariate_sum_to_1 ():
680
+ with pm .Model () as model :
681
+ x_1d = pm .Normal (
682
+ "x_1d" ,
683
+ mu = [- 3 , - 1 , 1 , 2 ],
684
+ sigma = 1 ,
685
+ size = (4 ,),
686
+ transform = tr .univariate_sum_to_1 ,
687
+ )
688
+ x_2d = pm .Normal (
689
+ "x_2d" ,
690
+ mu = [- 3 , - 1 , 1 , 2 ],
691
+ sigma = 1 ,
692
+ size = (10 , 4 ),
693
+ transform = tr .univariate_sum_to_1 ,
694
+ )
695
+
696
+ log_p = model .compile_logp (sum = False )(
697
+ {"x_1d_sumto1__" : np .zeros (3 ), "x_2d_sumto1__" : np .zeros ((10 , 3 ))}
698
+ )
699
+ np .testing .assert_allclose (np .tile (log_p [0 ], (10 , 1 )), log_p [1 ])
700
+
701
+
702
+ def test_2d_multivariate_sum_to_1 ():
703
+ with pm .Model () as model :
704
+ x_1d = pm .MvNormal (
705
+ "x_1d" ,
706
+ mu = [- 1 , 1 ],
707
+ cov = np .eye (2 ),
708
+ transform = tr .multivariate_sum_to_1 ,
709
+ )
710
+ x_2d = pm .MvNormal (
711
+ "x_2d" ,
712
+ mu = [- 1 , 1 ],
713
+ cov = np .eye (2 ),
714
+ size = 2 ,
715
+ transform = tr .multivariate_sum_to_1 ,
716
+ )
717
+
718
+ log_p = model .compile_logp (sum = False )(
719
+ {"x_1d_sumto1__" : np .zeros (1 ), "x_2d_sumto1__" : np .zeros ((2 , 1 ))}
720
+ )
721
+ np .testing .assert_allclose (log_p [0 ], log_p [1 ])
0 commit comments