Skip to content

Commit 835f533

Browse files
Ricardo's suggestions
1 parent 1f00345 commit 835f533

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,8 +2063,6 @@ def make_node(self, rng, size, dtype, mu, W, alpha, tau):
20632063
tau = pt.as_tensor_variable(floatX(tau))
20642064

20652065
alpha = pt.as_tensor_variable(floatX(alpha))
2066-
msg = "the domain of alpha is: -1 < alpha < 1."
2067-
alpha = Assert(msg)(alpha, pt.lt(alpha, 1) and pt.gt(alpha, -1))
20682066

20692067
return super().make_node(rng, size, dtype, mu, W, alpha, tau)
20702068

@@ -2080,6 +2078,9 @@ def rng_fn(cls, rng: np.random.RandomState, mu, W, alpha, tau, size):
20802078
Journal of the Royal Statistical Society Series B, Royal Statistical Society,
20812079
vol. 63(2), pages 325-338. DOI: 10.1111/1467-9868.00288
20822080
"""
2081+
if np.all(alpha >= 1) or np.all(alpha <= -1):
2082+
raise ValueError("the domain of alpha is: -1 < alpha < 1")
2083+
20832084
if not scipy.sparse.issparse(W):
20842085
W = scipy.sparse.csr_matrix(W)
20852086
s = np.asarray(W.sum(axis=0))[0]

0 commit comments

Comments
 (0)