Skip to content

Commit 1b7b4bb

Browse files
AustinRochfordtwiecki
authored andcommitted
Don't broadcast in alltrue (#1452)
1 parent 4a21624 commit 1b7b4bb

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

pymc3/distributions/dist_math.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ def bound(logp, *conditions):
2929

3030

3131
def alltrue(vals):
32-
ret = 1
33-
for c in vals:
34-
ret = ret * (1 * c)
35-
return ret
32+
return tt.all([tt.all(1 * val) for val in vals])
3633

3734

3835
def logpow(x, m):

pymc3/tests/test_dist_math.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import theano.tensor as tt
3+
4+
from ..distributions.dist_math import alltrue
5+
6+
7+
def test_alltrue():
8+
assert alltrue([]).eval()
9+
assert alltrue([True]).eval()
10+
assert alltrue([tt.ones(10)]).eval()
11+
assert alltrue([tt.ones(10),
12+
5 * tt.ones(101)]).eval()
13+
assert alltrue([np.ones(10),
14+
5 * tt.ones(101)]).eval()
15+
assert alltrue([np.ones(10),
16+
True,
17+
5 * tt.ones(101)]).eval()
18+
assert alltrue([np.array([1, 2, 3]),
19+
True,
20+
5 * tt.ones(101)]).eval()
21+
22+
assert not alltrue([False]).eval()
23+
assert not alltrue([tt.zeros(10)]).eval()
24+
assert not alltrue([True,
25+
False]).eval()
26+
assert not alltrue([np.array([0, -1]),
27+
tt.ones(60)]).eval()
28+
assert not alltrue([np.ones(10),
29+
False,
30+
5 * tt.ones(101)]).eval()
31+
32+
33+
def test_alltrue_shape():
34+
vals = [True, tt.ones(10), tt.zeros(5)]
35+
36+
assert alltrue(vals).eval().shape == ()

0 commit comments

Comments
 (0)