Skip to content

Commit 598dd9d

Browse files
Add dims to *Ordered probs variables when appropriate (#5084)
* Add dims to *Ordered probs variables when appropriate * Make black happy * For real about black this time
1 parent e03f5bf commit 598dd9d

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pymc/distributions/discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,7 +1745,7 @@ class OrderedLogistic:
17451745
def __new__(cls, name, *args, compute_p=True, **kwargs):
17461746
out_rv = _OrderedLogistic(name, *args, **kwargs)
17471747
if compute_p:
1748-
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3])
1748+
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3], dims=kwargs.get("dims"))
17491749
return out_rv
17501750

17511751
@classmethod
@@ -1856,7 +1856,7 @@ class OrderedProbit:
18561856
def __new__(cls, name, *args, compute_p=True, **kwargs):
18571857
out_rv = _OrderedProbit(name, *args, **kwargs)
18581858
if compute_p:
1859-
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3])
1859+
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3], dims=kwargs.get("dims"))
18601860
return out_rv
18611861

18621862
@classmethod

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ class OrderedMultinomial:
752752
def __new__(cls, name, *args, compute_p=True, **kwargs):
753753
out_rv = _OrderedMultinomial(name, *args, **kwargs)
754754
if compute_p:
755-
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[4])
755+
pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[4], dims=kwargs.get("dims"))
756756
return out_rv
757757

758758
@classmethod

0 commit comments

Comments
 (0)