Skip to content

Commit 9e92b74

Browse files
Fix Categorical logp implementation
1 parent e875889 commit 9e92b74

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pymc3/distributions/discrete.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,7 @@ def dist(cls, p, **kwargs):
13541354

13551355

13561356
@_logp.register(CategoricalRV)
1357-
def categorical_logp(op, value, p, upper):
1357+
def categorical_logp(op, value, p):
13581358
r"""
13591359
Calculate log-probability of Categorical distribution at specified value.
13601360
@@ -1365,8 +1365,9 @@ def categorical_logp(op, value, p, upper):
13651365
values are desired the values must be provided in a numpy array or `TensorVariable`
13661366
13671367
"""
1368+
k = aet.shape(p)[-1]
1369+
p_ = p
13681370
p = p_ / aet.sum(p_, axis=-1, keepdims=True)
1369-
k = aet.shape(p_)[-1]
13701371
value_clip = aet.clip(value, 0, k - 1)
13711372

13721373
if p.ndim > 1:

0 commit comments

Comments
 (0)