Skip to content

Commit 613ce57

Browse files
Add dims to *Ordered probs variables when appropriate
1 parent e03f5bf commit 613ce57

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

pymc/distributions/discrete.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,7 +1745,8 @@ class OrderedLogistic:
17451745
def __new__(cls, name, *args, compute_p=True, **kwargs):
17461746
out_rv = _OrderedLogistic(name, *args, **kwargs)
17471747
if compute_p:
1748-
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3])
1748+
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3],
1749+
dims=kwargs.get('dims'))
17491750
return out_rv
17501751

17511752
@classmethod
@@ -1856,7 +1857,8 @@ class OrderedProbit:
18561857
def __new__(cls, name, *args, compute_p=True, **kwargs):
18571858
out_rv = _OrderedProbit(name, *args, **kwargs)
18581859
if compute_p:
1859-
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3])
1860+
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3],
1861+
dims=kwargs.get('dims'))
18601862
return out_rv
18611863

18621864
@classmethod

pymc/distributions/multivariate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,8 @@ class OrderedMultinomial:
752752
def __new__(cls, name, *args, compute_p=True, **kwargs):
753753
out_rv = _OrderedMultinomial(name, *args, **kwargs)
754754
if compute_p:
755-
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[4])
755+
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[4],
756+
dims=kwargs.get('dims'))
756757
return out_rv
757758

758759
@classmethod

0 commit comments

Comments
 (0)